feat: merge develop (#123)

* Support hybrid vector retrieval
* Enable figure and table reading in Azure DI
* Retrieve with multi-modal
* Fix mixed-up tables
* Add txt loader
* Add Anthropic Chat
* Raise an error when retrieving a help file
* Allow the same filename for different users if private is True
* Allow declaring extra LLM vendors
* Show chunks on the File page
* Allow Elasticsearch to get more docs
* Fix Cohere response (#86)
  * Remove Adobe pdfservices from the dependencies: kotaemon no longer relies on pdfservices for its core functionality, and pdfservices uses very outdated dependencies that cause conflicts.
  * Co-authored-by: trducng <trungduc1992@gmail.com>
* Add confidence score (#87)
  * Save question-answering data as a log file
  * Save the original information besides the rewritten info
  * Export the Cohere relevance score as a confidence score
  * Fix style check
* Upgrade the confidence score appearance (#90)
  * Highlight the relevance score
  * Round the relevance score; get the key from config instead of env
  * Return all scores from Cohere
  * Display the relevance score for images
* Remove columns and rows that contain all NaN in the Excel loader (#91)
  * Back to multiple joiner options; fix style
  * Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>, trducng <trungduc1992@gmail.com>
* Track retriever state
* Bump llama-index to version 0.10
* feat/save-azuredi-mhtml-to-markdown (#93)
  * fix: replace os.path with pathlib; change theflow.settings
  * refactor: based on pre-commit
  * chore: move the markdown-content saving function above removed_spans
  * Co-authored-by: jacky0218 <jacky0218@github.com>
* fix: losing first chunk (#94)
  * fix: update the method of preventing lost chunks
  * Co-authored-by: jacky0218 <jacky0218@github.com>
* fix: adding the base64 image in markdown (#95)
* feat: more chunk info on UI
* fix: error when reindexing files
* refactor: allow more exception trace information when using gpt4v
* feat: add an Excel reader that treats each worksheet as a document
* Persist loader information when indexing a file
* feat: allow hiding unneeded setting panels
* feat: allow a specific timezone when creating a conversation
* feat: add more confidence scores (#96)
  * Allow a list of rerankers
  * Export the LLM reranking score instead of filtering with a boolean
  * Get logprobs from LLMs; export logprobs instead of probs
  * Rename the Cohere reranking score
  * Call 2 rerankers at once
  * Run the QA pipeline for each chunk to get qa_score; get qa_score only in the final answer
  * Display more relevance scores
  * Define another LLMScoring instead of editing the original one
* feat: replace text length with token count in the file list
* ui: show index name instead of id in the settings
* feat(ai): restrict the vision temperature
* fix(ui): remove the misleading message about non-retrieved evidences
* feat(ui): show the reasoning name and description in the reasoning settings page
* feat(ui): show the version on the main window
* feat(ui): show the default LLM name in the settings page
* fix(conf): append the doc result in llm_scoring (#97)
* fix: constrain the maximum number of images
* feat(ui): allow filtering files by name in the file list page
* Fix exceeding-token-length errors for OpenAI embeddings by chunking then averaging (#99)
  * Average embeddings in case the text exceeds the max size
  * Add docstring
* fix: allow empty strings when calling embedding
* fix: update TruLens LLM ranking score for retrieval confidence, improve citation (#98)
  * Round only when displaying, not by default
  * Add the LLMTrulens reranking model and use it in the pipeline
  * fix: update UI display for the TruLens score
  * Co-authored-by: taprosoft <tadashi@cinnamon.is>
* feat: add question decomposition & few-shot rewrite pipeline (#89)
  * Create few-shot query rewriting; run and display the result in info_panel
  * Put the functions into separate modules
  * Add zero-shot question decomposition and default few-shot examples
  * Fix few-shot rewriting, question decomposition, and rewriting pipeline imports
  * fix: update decompose logic in the full QA pipeline
  * Co-authored-by: taprosoft <tadashi@cinnamon.is>
* fix: add utf-8 encoding when saving temporary markdown in vectorIndex (#101)
* fix: improve retrieval pipeline and relevance score display (#102)
  * Extend the first-round top_k with a multiplier
* feat: improve UI default settings and add a quick-switch option for pipelines
* fix: improve agent logic (#103)
  * Improve agent progress display, retrieval logic, and UI display
  * Less verbose debug log
  * feat: add a warning message for low confidence
  * LLM scoring enabled by default
* fix: hotfix image citation
* feat: update docx loader to handle merged table cells + handle zip file upload (#104)
  * refactor: pre-commit
  * fix: escape text in the download UI
* feat: optimize vector store query db (#105)
  * Add file_id to Chroma metadatas
  * Remove unnecessary logs and update the migrate script
  * Iterate through the file index; remove unused code
  * Co-authored-by: taprosoft <tadashi@cinnamon.is>
* fix: add OpenAI embedding exponential back-off
* fix: update the download_loader import
* refactor: codespell
* fix: update some default settings
* fix: update installation instructions
* fix: default chunk length in simple QA
* feat: add share-conversation feature and enable retrieval history (#108)
  * fix: update the share-conversation UI
  * Co-authored-by: taprosoft <tadashi@cinnamon.is>
* fix: allow exponential backoff for failed OCR calls (#109)
* fix: update the default prompt when no retrieval is used
* fix: create embeddings for long image chunks
* fix: add exception handling for the additional table retriever
* fix: clean up conversation & file selection UI
* fix: Elasticsearch with empty doc_ids
* feat: add thumbnail PDF reader for quick multimodal QA
* feat: add thumbnail handling logic in indexing
* fix: UI text update
* fix: PDF thumbnail loader page number logic
* feat: add quick indexing pipeline and update UI
* feat: add conversation name suggestion (also in regen)
* fix: minor UI change
* feat: citation in thread
* chore: add assets for the usage doc and update it
* feat: pdf viewer (#110)
  * Update pdfviewer and missing files
  * Update the rendering logic of the info panel
  * Improve thumbnail retrieval and PDF evidence rendering logic
  * Remove the pdfjs built dist; reduce thumbnail evidence count
  * chore: update gitignore
  * Add a js event on chat message select; update css for the viewer
  * Add an env var for the PDFJS prebuilt; move the language setting to reasoning utils
  * Co-authored-by: phv2312 <kat87yb@gmail.com>, trducng <trungduc1992@gmail.com>
* feat: graph rag (#116)
  * Reload the server when adding/deleting an index
  * Rework the indexing pipeline so the vectorstore and splitter can be disabled if needed
  * Add a graphRAG index with plot view
  * Update requirements for graphRAG and lighten unnecessary packages
* feat: add Knowledge Network index (#118)
  * Update the reader mode setting and init for knet
  * Update the collection name in the index pipeline; add missing requirement
  * Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>
* fix: update info panel return for graphrag
* fix: retriever setting for graphrag
* feat: local llm settings (#122)
  * Expose context length as a reasoning setting to better fit local models
  * Update the context length setting for agents
* fix: rework threadpool llm call
* fix: improve indexing logic and UI
* feat: add lancedb vectorstore; improve lancedb logic and docstore
* fix: lighten requirements
* fix: openai retry
* fix: update reqs and launch command
* feat: update Dockerfile
* feat: add plot history
* fix: update default config/settings, gradio plot return, and default gradio tmp
* fix: remove verbose print
* fix: fix question decompose pipeline
* feat: add multimodal reader in UI
* fix: update docs, default settings & docker build, and app startup
* chore: update documentation and README

Co-authored-by: trducng <trungduc1992@gmail.com>
Co-authored-by: cin-ace <ace@cinnamon.is>
Co-authored-by: Linh Nguyen <70562198+linhnguyen-cinnamon@users.noreply.github.com>
Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
Co-authored-by: cin-jacky <101088014+jacky0218@users.noreply.github.com>
Co-authored-by: jacky0218 <jacky0218@github.com>
Co-authored-by: kan_cin <kan@cinnamon.is>
Co-authored-by: phv2312 <kat87yb@gmail.com>
Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>
.dockerignore (new file, 13 lines)
@@ -0,0 +1,13 @@
+.github/
+.git/
+.mypy_cache/
+__pycache__/
+ktem_app_data/
+env/
+.pre-commit-config.yaml
+.commitlintrc
+.gitignore
+.gitattributes
+README.md
+*.zip
+*.sh
.env (25 changes)
@@ -1,8 +1,8 @@
 # settings for OpenAI
 OPENAI_API_BASE=https://api.openai.com/v1
-OPENAI_API_KEY=
-OPENAI_CHAT_MODEL=gpt-3.5-turbo
-OPENAI_EMBEDDINGS_MODEL=text-embedding-ada-002
+OPENAI_API_KEY=openai_key
+OPENAI_CHAT_MODEL=gpt-4o
+OPENAI_EMBEDDINGS_MODEL=text-embedding-3-small
 
 # settings for Azure OpenAI
 AZURE_OPENAI_ENDPOINT=
@@ -15,4 +15,21 @@ AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=text-embedding-ada-002
 COHERE_API_KEY=
 
 # settings for local models
-LOCAL_MODEL=
+LOCAL_MODEL=llama3.1:8b
+LOCAL_MODEL_EMBEDDINGS=nomic-embed-text
+
+# settings for GraphRAG
+GRAPHRAG_API_KEY=openai_key
+GRAPHRAG_LLM_MODEL=gpt-4o-mini
+GRAPHRAG_EMBEDDING_MODEL=text-embedding-3-small
+
+# settings for Azure DI
+AZURE_DI_ENDPOINT=
+AZURE_DI_CREDENTIAL=
+
+# settings for Adobe API
+PDF_SERVICES_CLIENT_ID=
+PDF_SERVICES_CLIENT_SECRET=
+
+# settings for PDF.js
+PDFJS_VERSION_DIST="pdfjs-4.0.379-dist"
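These values only take effect when the application reads them. Judging by the `config("...", default=...)` calls in `flowsettings.py` later in this diff, that happens through python-decouple, which checks the process environment first and falls back to `.env`. A minimal sketch of the pattern (variable names from the file above; this is not kotaemon's actual code):

```python
# Minimal sketch, assuming python-decouple's `config` helper.
from decouple import config

OPENAI_API_KEY = config("OPENAI_API_KEY", default="")
LOCAL_MODEL = config("LOCAL_MODEL", default="")

# flowsettings.py-style gating: only register a local model if one is set
if LOCAL_MODEL:
    print(f"Registering local model: {LOCAL_MODEL}")
```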
.gitignore (vendored, 1 change)
@@ -471,3 +471,4 @@ doc_env/
 
 # application data
 ktem_app_data/
+gradio_tmp/
Dockerfile (new file, 37 lines)
@@ -0,0 +1,37 @@
+# syntax=docker/dockerfile:1.0.0-experimental
+FROM python:3.10-slim as base_image
+
+# for additional file parsers
+
+# tesseract-ocr \
+# tesseract-ocr-jpn \
+# libsm6 \
+# libxext6 \
+# ffmpeg \
+
+RUN apt update -qqy \
+    && apt install -y \
+        ssh git \
+        gcc g++ \
+        poppler-utils \
+        libpoppler-dev \
+    && \
+    apt-get clean && \
+    apt-get autoremove
+
+ENV PYTHONDONTWRITEBYTECODE=1
+ENV PYTHONUNBUFFERED=1
+ENV PYTHONIOENCODING=UTF-8
+
+WORKDIR /app
+
+
+FROM base_image as dev
+
+COPY . /app
+RUN --mount=type=ssh pip install -e "libs/kotaemon[all]"
+RUN --mount=type=ssh pip install -e "libs/ktem"
+RUN pip install graphrag future
+RUN pip install "pdfservices-sdk@git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements"
+
+ENTRYPOINT ["gradio", "app.py"]
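A note on building this image: the `RUN --mount=type=ssh` steps require BuildKit, so a local build presumably needs something like `DOCKER_BUILDKIT=1 docker build --ssh default -t kotaemon:dev .` (the tag name is an example, not from this PR).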
README.md (206 changes)
@@ -1,12 +1,12 @@
 # kotaemon
 
-An open-source tool for chatting with your documents. Built with both end users and
+An open-source clean & customizable RAG UI for chatting with your documents. Built with both end users and
 developers in mind.
 
-https://github.com/Cinnamon/kotaemon/assets/25688648/815ecf68-3a02-4914-a0dd-3f8ec7e75cd9
+
 
-[Source Code](https://github.com/Cinnamon/kotaemon) |
-[Live Demo](https://huggingface.co/spaces/cin-model/kotaemon-public)
+[Live Demo](https://huggingface.co/spaces/taprosoft/kotaemon) |
+[Source Code](https://github.com/Cinnamon/kotaemon)
 
 [User Guide](https://cinnamon.github.io/kotaemon/) |
 [Developer Guide](https://cinnamon.github.io/kotaemon/development/) |
@@ -14,20 +14,23 @@ https://github.com/Cinnamon/kotaemon/assets/25688648/815ecf68-3a02-4914-a0dd-3f8
 
 [](https://www.python.org/downloads/release/python-31013/)
 [](https://github.com/psf/black)
+<a href="https://hub.docker.com/r/taprosoft/kotaemon" target="_blank">
+<img src="https://img.shields.io/badge/docker_pull-kotaemon:v1.0-brightgreen" alt="docker pull taprosoft/kotaemon:v1.0"></a>
 [](https://codeium.com)
 
-This project would like to appeal to both end users who want to do QA on their
-documents and developers who want to build their own QA pipeline.
+## Introduction
+
+This project serves as a functional RAG UI for both end users who want to do QA on their
+documents and developers who want to build their own RAG pipeline.
 
 - For end users:
-  - A local Question Answering UI for RAG-based QA.
+  - A clean & minimalistic UI for RAG-based QA.
   - Supports LLM API providers (OpenAI, AzureOpenAI, Cohere, etc) and local LLMs
-    (currently only GGUF format is supported via `llama-cpp-python`).
-  - Easy installation scripts, no environment setup required.
+    (via `ollama` and `llama-cpp-python`).
+  - Easy installation scripts.
 - For developers:
-  - A framework for building your own RAG-based QA pipeline.
-  - See your RAG pipeline in action with the provided UI (built with Gradio).
-  - Share your pipeline so that others can use it.
+  - A framework for building your own RAG-based document QA pipeline.
+  - Customize and see your RAG pipeline in action with the provided UI (built with Gradio).
 
 ```yml
 +----------------------------------------------------------------------------+
@@ -45,78 +48,128 @@ documents and developers who want to build their own QA pipeline.
 ```
 
 This repository is under active development. Feedback, issues, and PRs are highly
-appreciated. Your input is valuable as it helps us persuade our business guys to support
-open source.
+appreciated.
+
+## Key Features
+
+- **Host your own document QA (RAG) web-UI**. Support multi-user login, organize your files in private / public collections, collaborate and share your favorite chat with others.
+
+- **Organize your LLM & Embedding models**. Support both local LLMs & popular API providers (OpenAI, Azure, Ollama, Groq).
+
+- **Hybrid RAG pipeline**. Sane default RAG pipeline with hybrid (full-text & vector) retriever + re-ranking to ensure best retrieval quality.
+
+- **Multi-modal QA support**. Perform Question Answering on multiple documents with figures & tables support. Support multi-modal document parsing (selectable options on UI).
+
+- **Advance citations with document preview**. By default the system will provide detailed citations to ensure the correctness of LLM answers. View your citations (incl. relevant score) directly in the _in-browser PDF viewer_ with highlights. Warning when retrieval pipeline return low relevant articles.
+
+- **Support complex reasoning methods**. Use question decomposition to answer your complex / multi-hop question. Support agent-based reasoning with ReAct, ReWOO and other agents.
+
+- **Configurable settings UI**. You can adjust most important aspects of retrieval & generation process on the UI (incl. prompts).
+
+- **Extensible**. Being built on Gradio, you are free to customize / add any UI elements as you like. Also, we aim to support multiple strategies for document indexing & retrieval. `GraphRAG` indexing pipeline is provided as an example.
+
+
 
 ## Installation
 
 ### For end users
 
 This document is intended for developers. If you just want to install and use the app as
-it, please follow the [User Guide](https://cinnamon.github.io/kotaemon/).
+it is, please follow the non-technical [User Guide](https://cinnamon.github.io/kotaemon/) (WIP).
 
 ### For developers
 
-```shell
-# Create a environment
-python -m venv kotaemon-env
+#### With Docker (recommended)
 
-# Activate the environment
-source kotaemon-env/bin/activate
+- Use this command to launch the server
 
-# Install the package
-pip install git+https://github.com/Cinnamon/kotaemon.git
+```
+docker run \
+-e GRADIO_SERVER_NAME=0.0.0.0 \
+-e GRADIO_SERVER_PORT=7860 \
+-p 7860:7860 -it --rm \
+taprosoft/kotaemon:v1.0
 ```
 
-### For Contributors
+Navigate to `http://localhost:7860/` to access the web UI.
+
+#### Without Docker
+
+- Clone and install required packages on a fresh python environment.
 
 ```shell
-# Clone the repo
-git clone git@github.com:Cinnamon/kotaemon.git
+# optional (setup env)
+conda create -n kotaemon python=3.10
+conda activate kotaemon
 
-# Create a environment
-python -m venv kotaemon-env
-
-# Activate the environment
-source kotaemon-env/bin/activate
+# clone this repo
+git clone https://github.com/Cinnamon/kotaemon
 cd kotaemon
 
-# Install the package in editable mode
 pip install -e "libs/kotaemon[all]"
 pip install -e "libs/ktem"
-pip install -e "."
-
-# Setup pre-commit
-pre-commit install
 ```
 
-## Creating your application
+- View and edit your environment variables (API keys, end-points) in `.env`.
 
-In order to create your own application, you need to prepare these files:
+- (Optional) To enable in-browser PDF_JS viewer, download [PDF_JS_DIST](https://github.com/mozilla/pdf.js/releases/download/v4.0.379/pdfjs-4.0.379-dist.zip) and extract it to `libs/ktem/ktem/assets/prebuilt`
+
+  <img src="docs/images/pdf-viewer-setup.png" alt="pdf-setup" width="300">
+
+- Start the web server:
+
+  ```shell
+  python app.py
+  ```
+
+  The app will be automatically launched in your browser.
+
+  Default username / password are: `admin` / `admin`. You can setup additional users directly on the UI.
+
+
+
+## Customize your application
+
+By default, all application data are stored in `./ktem_app_data` folder. You can backup or copy this folder to move your installation to a new machine.
+
+For advance users or specific use-cases, you can customize those files:
 
 - `flowsettings.py`
-- `app.py`
-- `.env` (Optional)
+- `.env`
 
 ### `flowsettings.py`
 
 This file contains the configuration of your application. You can use the example
-[here](https://github.com/Cinnamon/kotaemon/blob/main/libs/ktem/flowsettings.py) as the
+[here](flowsettings.py) as the
 starting point.
 
-### `app.py`
+<details>
 
-This file is where you create your Gradio app object. This can be as simple as:
+<summary>Notable settings</summary>
 
-```python
-from ktem.main import App
+```
+# setup your preferred document store (with full-text search capabilities)
+KH_DOCSTORE=(Elasticsearch | LanceDB | SimpleFileDocumentStore)
 
-app = App()
-demo = app.make()
-demo.launch()
+# setup your preferred vectorstore (for vector-based search)
+KH_VECTORSTORE=(ChromaDB | LanceDB)
+
+# Enable / disable multimodal QA
+KH_REASONINGS_USE_MULTIMODAL=True
+
+# Setup your new reasoning pipeline or modify existing one.
+KH_REASONINGS = [
+    "ktem.reasoning.simple.FullQAPipeline",
+    "ktem.reasoning.simple.FullDecomposeQAPipeline",
+    "ktem.reasoning.react.ReactAgentPipeline",
+    "ktem.reasoning.rewoo.RewooAgentPipeline",
+]
 ```
 
-### `.env` (Optional)
+</details>
+
+### `.env`
 
 This file provides another way to configure your models and credentials.
 
@@ -159,18 +212,22 @@ AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=text-embedding-ada-002
 
 #### Local models
 
-- Pros:
-  - Privacy. Your documents will be stored and process locally.
-  - Choices. There are a wide range of LLMs in terms of size, domain, language to choose
-    from.
-  - Cost. It's free.
-- Cons:
-  - Quality. Local models are much smaller and thus have lower generative quality than
-    paid APIs.
-  - Speed. Local models are deployed using your machine so the processing speed is
-    limited by your hardware.
+##### Using ollama OpenAI compatible server
 
-##### Find and download a LLM
+Install [ollama](https://github.com/ollama/ollama) and start the application.
+
+Pull your model (e.g):
+
+```
+ollama pull llama3.1:8b
+ollama pull nomic-embed-text
+```
+
+Set the model names on web UI and make it as default.
+
+
+
+##### Using GGUF with llama-cpp-python
 
 You can search and download a LLM to be ran locally from the [Hugging Face
 Hub](https://huggingface.co/models). Currently, these model formats are supported:
@@ -187,33 +244,26 @@ Here are some recommendations and their size in memory:
 - [Qwen1.5-1.8B-Chat-GGUF](https://huggingface.co/Qwen/Qwen1.5-1.8B-Chat-GGUF/resolve/main/qwen1_5-1_8b-chat-q8_0.gguf?download=true):
   around 2 GB
 
-##### Enable local models
-
-To add a local model to the model pool, set the `LOCAL_MODEL` variable in the `.env`
-file to the path of the model file.
-
-```shell
-LOCAL_MODEL=<full path to your model file>
-```
-
-Here is how to get the full path of your model file:
-
-- On Windows 11: right click the file and select `Copy as Path`.
+Add a new LlamaCpp model with the provided model name on the web UI.
+
 </details>
 
-## Start your application
+## Adding your own RAG pipeline
 
-Simply run the following command:
+#### Custom reasoning pipeline
 
-```shell
-python app.py
-```
+First, check the default pipeline implementation in
+[here](libs/ktem/ktem/reasoning/simple.py). You can make quick adjustment to how the default QA pipeline work.
 
-The app will be automatically launched in your browser.
+Next, if you feel comfortable adding new pipeline, add new `.py` implementation in `libs/ktem/ktem/reasoning/` and later include it in `flowssettings` to enable it on the UI.
 
-
+#### Custom indexing pipeline
 
-## Customize your application
+Check sample implementation in `libs/ktem/ktem/index/file/graph`
+
+(more instruction WIP).
+
+## Developer guide
 
 Please refer to the [Developer Guide](https://cinnamon.github.io/kotaemon/development/)
 for more details.
app.py (23 changes)
@@ -1,5 +1,24 @@
-from ktem.main import App
+import os
+
+from theflow.settings import settings as flowsettings
+
+KH_APP_DATA_DIR = getattr(flowsettings, "KH_APP_DATA_DIR", ".")
+GRADIO_TEMP_DIR = os.getenv("GRADIO_TEMP_DIR", None)
+# set GRADIO_TEMP_DIR if it's not already set
+if GRADIO_TEMP_DIR is None:
+    GRADIO_TEMP_DIR = os.path.join(KH_APP_DATA_DIR, "gradio_tmp")
+    os.environ["GRADIO_TEMP_DIR"] = GRADIO_TEMP_DIR
+
+
+from ktem.main import App  # noqa
 
 app = App()
 demo = app.make()
-demo.queue().launch(favicon_path=app._favicon, inbrowser=True)
+demo.queue().launch(
+    favicon_path=app._favicon,
+    inbrowser=True,
+    allowed_paths=[
+        "libs/ktem/ktem/assets",
+        GRADIO_TEMP_DIR,
+    ],
+)
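Two things are worth noting in the new `app.py`: the Gradio temp directory is relocated under the app data folder before `ktem` is imported (hence the late import marked `# noqa`), and both the assets folder and the relocated temp directory are passed to `allowed_paths` so the Gradio server is permitted to serve files from them.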
@@ -9,3 +9,6 @@ developers in mind.
 [User Guide](https://cinnamon.github.io/kotaemon/) |
 [Developer Guide](https://cinnamon.github.io/kotaemon/development/) |
 [Feedback](https://github.com/Cinnamon/kotaemon/issues)
+
+[Dark Mode](?__theme=dark) |
+[Light Mode](?__theme=light)
BIN docs/images/info-panel-scores.png (new file, 545 KiB)
BIN docs/images/models.png (new file, 136 KiB)
BIN docs/images/pdf-viewer-setup.png (new file, 38 KiB)
BIN docs/images/preview-graph.png (new file, 288 KiB)
BIN docs/images/preview.png (new file, 566 KiB)
@@ -107,9 +107,9 @@ string rather than a string.
 
 ## Software infrastructure
 
 | Infra | Access | Schema | Ref |
 | ----- | ------ | ------ | --- |
-| SQL table Source | self.\_Source | - id (int): id of the source (auto)<br>- name (str): the name of the file<br>- path (str): the path of the file<br>- size (int): the file size in bytes<br>- text_length (int): the number of characters in the file (default 0)<br>- date_created (datetime): the time the file is created (auto) | This is SQLAlchemy ORM class. Can consult |
+| SQL table Source | self.\_Source | - id (int): id of the source (auto)<br>- name (str): the name of the file<br>- path (str): the path of the file<br>- size (int): the file size in bytes<br>- note (dict): allow extra optional information about the file<br>- date_created (datetime): the time the file is created (auto) | This is SQLAlchemy ORM class. Can consult |
 | SQL table Index | self.\_Index | - id (int): id of the index entry (auto)<br>- source_id (int): the id of a file in the Source table<br>- target_id: the id of the segment in docstore or vector store<br>- relation_type (str): if the link is "document" or "vector" | This is SQLAlchemy ORM class |
 | Vector store | self.\_VS | - self.\_VS.add: add the list of embeddings to the vector store (optionally associate metadata and ids)<br>- self.\_VS.delete: delete vector entries based on ids<br>- self.\_VS.query: get embeddings based on embeddings. | kotaemon > storages > vectorstores > BaseVectorStore |
 | Doc store | self.\_DS | - self.\_DS.add: add the segments to document stores<br>- self.\_DS.get: get the segments based on id<br>- self.\_DS.get_all: get all segments<br>- self.\_DS.delete: delete segments based on id | kotaemon > storages > docstores > base > BaseDocumentStore |
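For orientation, here is a hypothetical SQLAlchemy rendering of the changed Source schema (column names come from the table above; the class, types, and defaults are assumptions, not kotaemon's actual code):

```python
# Hypothetical reconstruction of the Source table described above.
from datetime import datetime

from sqlalchemy import JSON, Column, DateTime, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()


class Source(Base):
    __tablename__ = "source"

    id = Column(Integer, primary_key=True, autoincrement=True)  # auto id
    name = Column(String)  # the name of the file
    path = Column(String)  # the path of the file
    size = Column(Integer, default=0)  # file size in bytes
    note = Column(JSON, default=dict)  # extra optional info; replaces text_length
    date_created = Column(DateTime, default=datetime.utcnow)  # creation time
```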
@@ -1,5 +1,3 @@
-# Basic Usage
-
 ## 1. Add your AI models
 
 
@@ -63,12 +61,15 @@ AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT=text-embedding-ada-002 # change to your deplo
 
 ### Local models
 
-- Pros:
+Pros:
+
 - Privacy. Your documents will be stored and process locally.
 - Choices. There are a wide range of LLMs in terms of size, domain, language to choose
   from.
 - Cost. It's free.
-- Cons:
+
+Cons:
+
 - Quality. Local models are much smaller and thus have lower generative quality than
   paid APIs.
 - Speed. Local models are deployed using your machine so the processing speed is
@@ -136,6 +137,21 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
    files will be considered during chat.
 2. Chat Panel
    - This is where you can chat with the chatbot.
-3. Information panel
-   - Supporting information such as the retrieved evidence and reference will be
-     displayed here.
+3. Information Panel
+
+
+
+   - Supporting information such as the retrieved evidence and reference will be
+     displayed here.
+   - Direct citation for the answer produced by the LLM is highlighted.
+   - The confidence score of the answer and relevant scores of evidences are displayed to quickly assess the quality of the answer and retrieved content.
+
+   - Meaning of the score displayed:
+     - **Answer confidence**: answer confidence level from the LLM model.
+     - **Relevance score**: overall relevant score between evidence and user question.
+     - **Vectorstore score**: relevant score from vector embedding similarity calculation (show `full-text search` if retrieved from full-text search DB).
+     - **LLM relevant score**: relevant score from LLM model (which judge relevancy between question and evidence using specific prompt).
+     - **Reranking score**: relevant score from Cohere [reranking model](https://cohere.com/rerank).
+
+   Generally, the score quality is `LLM relevant score` > `Reranking score` > `Vectorscore`.
+   By default, overall relevance score is taken directly from LLM relevant score. Evidences are sorted based on their overall relevance score and whether they have citation or not.
flowsettings.py (169 changes)
@@ -15,7 +15,7 @@ this_dir = Path(this_file).parent
 # change this if your app use a different name
 KH_PACKAGE_NAME = "kotaemon_app"
 
-KH_APP_VERSION = os.environ.get("KH_APP_VERSION", None)
+KH_APP_VERSION = config("KH_APP_VERSION", "local")
 if not KH_APP_VERSION:
     try:
         # Caution: This might produce the wrong version
@@ -33,8 +33,21 @@ KH_APP_DATA_DIR.mkdir(parents=True, exist_ok=True)
 KH_USER_DATA_DIR = KH_APP_DATA_DIR / "user_data"
 KH_USER_DATA_DIR.mkdir(parents=True, exist_ok=True)
 
-# doc directory
-KH_DOC_DIR = this_dir / "docs"
+# markdown output directory
+KH_MARKDOWN_OUTPUT_DIR = KH_APP_DATA_DIR / "markdown_cache_dir"
+KH_MARKDOWN_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+# chunks output directory
+KH_CHUNKS_OUTPUT_DIR = KH_APP_DATA_DIR / "chunks_cache_dir"
+KH_CHUNKS_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+# zip output directory
+KH_ZIP_OUTPUT_DIR = KH_APP_DATA_DIR / "zip_cache_dir"
+KH_ZIP_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
+
+# zip input directory
+KH_ZIP_INPUT_DIR = KH_APP_DATA_DIR / "zip_cache_dir_in"
+KH_ZIP_INPUT_DIR.mkdir(parents=True, exist_ok=True)
 
 # HF models can be big, let's store them in the app data directory so that it's easier
 # for users to manage their storage.
@@ -42,24 +55,30 @@ KH_DOC_DIR = this_dir / "docs"
 os.environ["HF_HOME"] = str(KH_APP_DATA_DIR / "huggingface")
 os.environ["HF_HUB_CACHE"] = str(KH_APP_DATA_DIR / "huggingface")
 
-COHERE_API_KEY = config("COHERE_API_KEY", default="")
+# doc directory
+KH_DOC_DIR = this_dir / "docs"
 
 KH_MODE = "dev"
-KH_FEATURE_USER_MANAGEMENT = False
+KH_FEATURE_USER_MANAGEMENT = True
+KH_USER_CAN_SEE_PUBLIC = None
 KH_FEATURE_USER_MANAGEMENT_ADMIN = str(
     config("KH_FEATURE_USER_MANAGEMENT_ADMIN", default="admin")
 )
 KH_FEATURE_USER_MANAGEMENT_PASSWORD = str(
-    config("KH_FEATURE_USER_MANAGEMENT_PASSWORD", default="XsdMbe8zKP8KdeE@")
+    config("KH_FEATURE_USER_MANAGEMENT_PASSWORD", default="admin")
 )
 KH_ENABLE_ALEMBIC = False
 KH_DATABASE = f"sqlite:///{KH_USER_DATA_DIR / 'sql.db'}"
 KH_FILESTORAGE_PATH = str(KH_USER_DATA_DIR / "files")
 
 KH_DOCSTORE = {
-    "__type__": "kotaemon.storages.SimpleFileDocumentStore",
+    # "__type__": "kotaemon.storages.ElasticsearchDocumentStore",
+    # "__type__": "kotaemon.storages.SimpleFileDocumentStore",
+    "__type__": "kotaemon.storages.LanceDBDocumentStore",
     "path": str(KH_USER_DATA_DIR / "docstore"),
 }
 KH_VECTORSTORE = {
+    # "__type__": "kotaemon.storages.LanceDBVectorStore",
     "__type__": "kotaemon.storages.ChromaVectorStore",
     "path": str(KH_USER_DATA_DIR / "vectorstore"),
 }
@@ -83,8 +102,6 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
             "timeout": 20,
         },
         "default": False,
-        "accuracy": 5,
-        "cost": 5,
     }
 if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
     KH_EMBEDDINGS["azure"] = {
@@ -110,71 +127,66 @@ if config("OPENAI_API_KEY", default=""):
             "base_url": config("OPENAI_API_BASE", default="")
             or "https://api.openai.com/v1",
             "api_key": config("OPENAI_API_KEY", default=""),
-            "model": config("OPENAI_CHAT_MODEL", default="") or "gpt-3.5-turbo",
+            "model": config("OPENAI_CHAT_MODEL", default="gpt-3.5-turbo"),
+            "timeout": 20,
+        },
+        "default": True,
+    }
+    KH_EMBEDDINGS["openai"] = {
+        "spec": {
+            "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
+            "base_url": config("OPENAI_API_BASE", default="https://api.openai.com/v1"),
+            "api_key": config("OPENAI_API_KEY", default=""),
+            "model": config(
+                "OPENAI_EMBEDDINGS_MODEL", default="text-embedding-ada-002"
+            ),
             "timeout": 10,
+            "context_length": 8191,
-        },
-        "default": False,
-    }
-
-if len(KH_EMBEDDINGS) < 1:
-    KH_EMBEDDINGS["openai"] = {
-        "spec": {
-            "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
-            "base_url": config("OPENAI_API_BASE", default="")
-            or "https://api.openai.com/v1",
-            "api_key": config("OPENAI_API_KEY", default=""),
-            "model": config(
-                "OPENAI_EMBEDDINGS_MODEL", default="text-embedding-ada-002"
-            )
-            or "text-embedding-ada-002",
-            "timeout": 10,
-        },
-        "default": False,
-    }
-
-if config("LOCAL_MODEL", default=""):
-    KH_LLMS["local"] = {
-        "spec": {
-            "__type__": "kotaemon.llms.EndpointChatLLM",
-            "endpoint_url": "http://localhost:31415/v1/chat/completions",
-        },
-        "default": False,
-        "cost": 0,
-    }
-    if len(KH_EMBEDDINGS) < 1:
-        KH_EMBEDDINGS["local"] = {
-            "spec": {
-                "__type__": "kotaemon.embeddings.EndpointEmbeddings",
-                "endpoint_url": "http://localhost:31415/v1/embeddings",
-            },
-            "default": False,
-            "cost": 0,
-        }
-
-if len(KH_EMBEDDINGS) < 1:
-    KH_EMBEDDINGS["local-bge-base-en-v1.5"] = {
-        "spec": {
-            "__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
-            "model_name": "BAAI/bge-base-en-v1.5",
         },
         "default": True,
     }
 
-KH_REASONINGS = ["ktem.reasoning.simple.FullQAPipeline"]
+if config("LOCAL_MODEL", default=""):
+    KH_LLMS["ollama"] = {
+        "spec": {
+            "__type__": "kotaemon.llms.ChatOpenAI",
+            "base_url": "http://localhost:11434/v1/",
+            "model": config("LOCAL_MODEL", default="llama3.1:8b"),
+        },
+        "default": False,
+    }
+    KH_EMBEDDINGS["ollama"] = {
+        "spec": {
+            "__type__": "kotaemon.embeddings.OpenAIEmbeddings",
+            "base_url": "http://localhost:11434/v1/",
+            "model": config("LOCAL_MODEL_EMBEDDINGS", default="nomic-embed-text"),
+        },
+        "default": False,
+    }
+
+KH_EMBEDDINGS["local-bge-en"] = {
+    "spec": {
+        "__type__": "kotaemon.embeddings.FastEmbedEmbeddings",
+        "model_name": "BAAI/bge-base-en-v1.5",
+    },
+    "default": False,
+}
+
+KH_REASONINGS = [
+    "ktem.reasoning.simple.FullQAPipeline",
+    "ktem.reasoning.simple.FullDecomposeQAPipeline",
+    "ktem.reasoning.react.ReactAgentPipeline",
+    "ktem.reasoning.rewoo.RewooAgentPipeline",
+]
+KH_REASONINGS_USE_MULTIMODAL = False
 KH_VLM_ENDPOINT = "{0}/openai/deployments/{1}/chat/completions?api-version={2}".format(
     config("AZURE_OPENAI_ENDPOINT", default=""),
-    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4-vision"),
+    config("OPENAI_VISION_DEPLOYMENT_NAME", default="gpt-4o"),
     config("OPENAI_API_VERSION", default=""),
 )
 
-SETTINGS_APP = {
-    "lang": {
-        "name": "Language",
-        "value": "en",
-        "choices": [("English", "en"), ("Japanese", "ja")],
-        "component": "dropdown",
-    }
-}
+SETTINGS_APP: dict[str, dict] = {}
 
 
 SETTINGS_REASONING = {
@@ -187,17 +199,42 @@ SETTINGS_REASONING = {
     "lang": {
         "name": "Language",
         "value": "en",
-        "choices": [("English", "en"), ("Japanese", "ja")],
+        "choices": [("English", "en"), ("Japanese", "ja"), ("Vietnamese", "vi")],
         "component": "dropdown",
     },
+    "max_context_length": {
+        "name": "Max context length (LLM)",
+        "value": 32000,
+        "component": "number",
+    },
 }
 
 
-KH_INDEX_TYPES = ["ktem.index.file.FileIndex"]
+KH_INDEX_TYPES = [
+    "ktem.index.file.FileIndex",
+    "ktem.index.file.graph.GraphRAGIndex",
+]
 KH_INDICES = [
     {
         "name": "File",
-        "config": {},
+        "config": {
+            "supported_file_types": (
+                ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
+                ".pptx, .csv, .html, .mhtml, .txt, .zip"
+            ),
+            "private": False,
+        },
        "index_type": "ktem.index.file.FileIndex",
     },
+    {
+        "name": "GraphRAG",
+        "config": {
+            "supported_file_types": (
+                ".png, .jpeg, .jpg, .tiff, .tif, .pdf, .xls, .xlsx, .doc, .docx, "
+                ".pptx, .csv, .html, .mhtml, .txt, .zip"
+            ),
+            "private": False,
+        },
+        "index_type": "ktem.index.file.graph.GraphRAGIndex",
+    },
 ]
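The new "ollama" entries work because ollama exposes an OpenAI-compatible server on port 11434, so kotaemon can reuse its stock `ChatOpenAI` / `OpenAIEmbeddings` specs unchanged. A quick standalone way to sanity-check that endpoint (assumes `ollama pull llama3.1:8b` was run; the api_key value is arbitrary since ollama ignores it):

```python
# Hedged sketch: exercise ollama's OpenAI-compatible endpoint directly.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:11434/v1/", api_key="ollama")
reply = client.chat.completions.create(
    model="llama3.1:8b",
    messages=[{"role": "user", "content": "Say hello in five words."}],
)
print(reply.choices[0].message.content)
```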
@@ -39,16 +39,11 @@ class ReactAgent(BaseAgent):
     )
     max_iterations: int = 5
     strict_decode: bool = False
-    trim_func: TokenSplitter = TokenSplitter.withx(
-        chunk_size=800,
-        chunk_overlap=0,
-        separator=" ",
-        tokenizer=partial(
-            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
-            allowed_special=set(),
-            disallowed_special="all",
-        ),
+    max_context_length: int = Param(
+        default=3000,
+        help="Max context length for each tool output.",
     )
+    trim_func: TokenSplitter | None = None
 
     def _compose_plugin_description(self) -> str:
         """
@@ -149,14 +144,28 @@ class ReactAgent(BaseAgent):
             function_map[plugin.name] = plugin
         return function_map
 
-    def _trim(self, text: str) -> str:
+    def _trim(self, text: str | Document) -> str:
         """
         Trim the text to the maximum token length.
         """
+        evidence_trim_func = (
+            self.trim_func
+            if self.trim_func
+            else TokenSplitter(
+                chunk_size=self.max_context_length,
+                chunk_overlap=0,
+                separator=" ",
+                tokenizer=partial(
+                    tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+                    allowed_special=set(),
+                    disallowed_special="all",
+                ),
+            )
+        )
         if isinstance(text, str):
-            texts = self.trim_func([Document(text=text)])
+            texts = evidence_trim_func([Document(text=text)])
         elif isinstance(text, Document):
-            texts = self.trim_func([text])
+            texts = evidence_trim_func([text])
         else:
             raise ValueError("Invalid text type to trim")
         trim_text = texts[0].text
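Both agents apply the same refactor: the `TokenSplitter` is no longer a class-level default but is built lazily from `max_context_length`, so the limit can be tuned per instance (it is exposed as a reasoning setting elsewhere in this PR). A standalone sketch of the fallback splitter; the import paths are guesses, not taken from this diff:

```python
# Sketch of the lazily-constructed trim function (import paths assumed).
from functools import partial

import tiktoken

from kotaemon.base import Document
from kotaemon.indices.splitters import TokenSplitter

splitter = TokenSplitter(
    chunk_size=1000,  # stands in for self.max_context_length
    chunk_overlap=0,
    separator=" ",
    tokenizer=partial(
        tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
        allowed_special=set(),
        disallowed_special="all",
    ),
)
# Trimming keeps only the first chunk of an oversized tool output:
trimmed = splitter([Document(text="word " * 5000)])[0].text
```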
@@ -39,16 +39,11 @@ class RewooAgent(BaseAgent):
     examples: dict[str, str | list[str]] = Param(
         default_callback=lambda _: {}, help="Examples to be used in the agent."
     )
-    trim_func: TokenSplitter = TokenSplitter.withx(
-        chunk_size=3000,
-        chunk_overlap=0,
-        separator=" ",
-        tokenizer=partial(
-            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
-            allowed_special=set(),
-            disallowed_special="all",
-        ),
+    max_context_length: int = Param(
+        default=3000,
+        help="Max context length for each tool output.",
     )
+    trim_func: TokenSplitter | None = None
 
     @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
     def planner(self):
@@ -248,8 +243,22 @@ class RewooAgent(BaseAgent):
         return p
 
     def _trim_evidence(self, evidence: str):
+        evidence_trim_func = (
+            self.trim_func
+            if self.trim_func
+            else TokenSplitter(
+                chunk_size=self.max_context_length,
+                chunk_overlap=0,
+                separator=" ",
+                tokenizer=partial(
+                    tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+                    allowed_special=set(),
+                    disallowed_special="all",
+                ),
+            )
+        )
         if evidence:
-            texts = self.trim_func([Document(text=evidence)])
+            texts = evidence_trim_func([Document(text=evidence)])
             evidence = texts[0].text
             logging.info(f"len (trimmed): {len(evidence)}")
         return evidence
@@ -317,6 +326,14 @@
             )
 
         print("Planner output:", planner_text_output)
+        # output planner to info panel
+        yield AgentOutput(
+            text="",
+            agent_type=self.agent_type,
+            status="thinking",
+            intermediate_steps=[{"planner_log": planner_text_output}],
+        )
+
         # Work
         worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
             planner_evidences, evidence_level
@@ -326,7 +343,9 @@
             worker_log += f"{plan}: {plans[plan]}\n"
             current_progress = f"{plan}: {plans[plan]}\n"
             for e in plan_to_es[plan]:
+                worker_log += f"#Action: {planner_evidences.get(e, None)}\n"
                 worker_log += f"{e}: {worker_evidences[e]}\n"
+                current_progress += f"#Action: {planner_evidences.get(e, None)}\n"
                 current_progress += f"{e}: {worker_evidences[e]}\n"
 
             yield AgentOutput(
@@ -1,7 +1,7 @@
 from typing import AnyStr, Optional, Type
 from urllib.error import HTTPError
 
-from langchain.utilities import SerpAPIWrapper
+from langchain_community.utilities import SerpAPIWrapper
 from pydantic import BaseModel, Field
 
 from .base import BaseTool
@@ -22,12 +22,16 @@ class LLMTool(BaseTool):
     )
     llm: BaseLLM
     args_schema: Optional[Type[BaseModel]] = LLMArgs
+    dummy_mode: bool = True
 
     def _run_tool(self, query: AnyStr) -> str:
         output = None
         try:
-            response = self.llm(query)
+            if not self.dummy_mode:
+                response = self.llm(query)
+            else:
+                response = None
         except ValueError:
             raise ToolException("LLM Tool call failed")
-        output = response.text
+        output = response.text if response else "<->"
         return output
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar
 from langchain.schema.messages import AIMessage as LCAIMessage
 from langchain.schema.messages import HumanMessage as LCHumanMessage
 from langchain.schema.messages import SystemMessage as LCSystemMessage
-from llama_index.bridge.pydantic import Field
-from llama_index.schema import Document as BaseDocument
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.schema import Document as BaseDocument
 
 if TYPE_CHECKING:
     from haystack.schema import Document as HaystackDocument
@@ -38,7 +38,7 @@ class Document(BaseDocument):
 
     content: Any = None
     source: Optional[str] = None
-    channel: Optional[Literal["chat", "info", "index", "debug"]] = None
+    channel: Optional[Literal["chat", "info", "index", "debug", "plot"]] = None
 
     def __init__(self, content: Optional[Any] = None, *args, **kwargs):
         if content is None:
@@ -140,6 +140,7 @@ class LLMInterface(AIMessage):
     total_cost: float = 0
     logits: list[list[float]] = Field(default_factory=list)
     messages: list[AIMessage] = Field(default_factory=list)
+    logprobs: list[float] = []
 
 
 class ExtractorOutput(Document):
@@ -133,9 +133,7 @@ def construct_chat_ui(
                 label="Output file", show_label=True, height=100
             )
             export_btn = gr.Button("Export")
-            export_btn.click(
-                func_export_to_excel, inputs=None, outputs=exported_file
-            )
+            export_btn.click(func_export_to_excel, inputs=[], outputs=exported_file)
 
     with gr.Row():
         with gr.Column():
@@ -91,7 +91,7 @@ def construct_pipeline_ui(
             save_btn.click(func_save, inputs=params, outputs=history_dataframe)
             load_params_btn = gr.Button("Reload params")
             load_params_btn.click(
-                func_load_params, inputs=None, outputs=history_dataframe
+                func_load_params, inputs=[], outputs=history_dataframe
             )
             history_dataframe.render()
             history_dataframe.select(
@@ -103,7 +103,7 @@ def construct_pipeline_ui(
             export_btn = gr.Button(
                 "Export (Result will be in Exported file next to Output)"
             )
-            export_btn.click(func_export, inputs=None, outputs=exported_file)
+            export_btn.click(func_export, inputs=[], outputs=exported_file)
             with gr.Row():
                 with gr.Column():
                     if params:
@@ -1,5 +1,15 @@
+from itertools import islice
 from typing import Optional

+import numpy as np
+import openai
+import tiktoken
+from tenacity import (
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 from theflow.utils.modules import import_dotted_string

 from kotaemon.base import Param

@@ -7,6 +17,24 @@ from kotaemon.base import Param
 from .base import BaseEmbeddings, Document, DocumentWithEmbedding


+def split_text_by_chunk_size(text: str, chunk_size: int) -> list[list[int]]:
+    """Split the text into chunks of a given size
+
+    Args:
+        text: text to split
+        chunk_size: size of each chunk
+
+    Returns:
+        list of chunks (as tokens)
+    """
+    encoding = tiktoken.get_encoding("cl100k_base")
+    tokens = iter(encoding.encode(text))
+    result = []
+    while chunk := list(islice(tokens, chunk_size)):
+        result.append(chunk)
+    return result
+
+
 class BaseOpenAIEmbeddings(BaseEmbeddings):
     """Base interface for OpenAI embedding model, using the openai library.

@@ -32,6 +60,9 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
             "Only supported in `text-embedding-3` and later models."
         ),
     )
+    context_length: Optional[int] = Param(
+        None, help="The maximum context length of the embedding model"
+    )

     @Param.auto(depends_on=["max_retries"])
     def max_retries_(self):

@@ -56,16 +87,42 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
     def invoke(
         self, text: str | list[str] | Document | list[Document], *args, **kwargs
     ) -> list[DocumentWithEmbedding]:
-        input_ = self.prepare_input(text)
+        input_doc = self.prepare_input(text)
         client = self.prepare_client(async_version=False)
-        resp = self.openai_response(
-            client, input=[_.text if _.text else " " for _ in input_], **kwargs
-        ).dict()
-        output_ = sorted(resp["data"], key=lambda x: x["index"])
-        return [
-            DocumentWithEmbedding(embedding=o["embedding"], content=i)
-            for i, o in zip(input_, output_)
-        ]
+
+        input_: list[str | list[int]] = []
+        splitted_indices = {}
+        for idx, text in enumerate(input_doc):
+            if self.context_length:
+                chunks = split_text_by_chunk_size(text.text or " ", self.context_length)
+                splitted_indices[idx] = (len(input_), len(input_) + len(chunks))
+                input_.extend(chunks)
+            else:
+                splitted_indices[idx] = (len(input_), len(input_) + 1)
+                input_.append(text.text)
+
+        resp = self.openai_response(client, input=input_, **kwargs).dict()
+        output_ = list(sorted(resp["data"], key=lambda x: x["index"]))
+
+        output = []
+        for idx, doc in enumerate(input_doc):
+            embs = output_[splitted_indices[idx][0] : splitted_indices[idx][1]]
+            if len(embs) == 1:
+                output.append(
+                    DocumentWithEmbedding(embedding=embs[0]["embedding"], content=doc)
+                )
+                continue
+
+            chunk_lens = [
+                len(_)
+                for _ in input_[splitted_indices[idx][0] : splitted_indices[idx][1]]
+            ]
+            vs: list[list[float]] = [_["embedding"] for _ in embs]
+            emb = np.average(vs, axis=0, weights=chunk_lens)
+            emb = emb / np.linalg.norm(emb)
+            output.append(DocumentWithEmbedding(embedding=emb.tolist(), content=doc))
+
+        return output

     async def ainvoke(
         self, text: str | list[str] | Document | list[Document], *args, **kwargs

@@ -118,6 +175,13 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):

         return OpenAI(**params)

+    @retry(
+        retry=retry_if_not_exception_type(
+            (openai.NotFoundError, openai.BadRequestError)
+        ),
+        wait=wait_random_exponential(min=1, max=40),
+        stop=stop_after_attempt(6),
+    )
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
         params: dict = {

@@ -174,6 +238,13 @@ class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):

         return AzureOpenAI(**params)

+    @retry(
+        retry=retry_if_not_exception_type(
+            (openai.NotFoundError, openai.BadRequestError)
+        ),
+        wait=wait_random_exponential(min=1, max=40),
+        stop=stop_after_attempt(6),
+    )
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
         params: dict = {
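The rewritten `invoke` handles inputs longer than the model's context window by embedding token chunks separately and combining them. The combination step is a length-weighted average followed by L2 normalization; a self-contained sketch of just that arithmetic (the vectors and chunk lengths are made up):

```python
import numpy as np

# Per-chunk embeddings (3-dimensional here for readability) and the
# number of tokens each chunk contributed.
chunk_embeddings = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
chunk_lens = [300, 100]  # the longer chunk dominates the average

emb = np.average(chunk_embeddings, axis=0, weights=chunk_lens)
emb = emb / np.linalg.norm(emb)  # re-normalize so cosine similarity still works
print(emb.tolist())  # ~[0.949, 0.316, 0.0]
```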
@@ -3,7 +3,7 @@ from __future__ import annotations
 from abc import abstractmethod
 from typing import Any, Type

-from llama_index.node_parser.interface import NodeParser
+from llama_index.core.node_parser.interface import NodeParser

 from kotaemon.base import BaseComponent, Document, RetrievedDocument

@@ -32,7 +32,7 @@ class LlamaIndexDocTransformerMixin:
     Example:
         class TokenSplitter(LlamaIndexMixin, BaseSplitter):
             def _get_li_class(self):
-                from llama_index.text_splitter import TokenTextSplitter
+                from llama_index.core.text_splitter import TokenTextSplitter
                 return TokenTextSplitter

     To use this mixin, please:
@@ -15,7 +15,7 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
         super().__init__(llm=llm, nodes=nodes, **params)

     def _get_li_class(self):
-        from llama_index.extractors import TitleExtractor
+        from llama_index.core.extractors import TitleExtractor

         return TitleExtractor

@@ -30,6 +30,6 @@ class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
         super().__init__(llm=llm, summaries=summaries, **params)

     def _get_li_class(self):
-        from llama_index.extractors import SummaryExtractor
+        from llama_index.core.extractors import SummaryExtractor

         return SummaryExtractor
@@ -1,27 +1,42 @@
 from pathlib import Path
 from typing import Type

-from llama_index.readers import PDFReader
-from llama_index.readers.base import BaseReader
+from decouple import config
+from llama_index.core.readers.base import BaseReader
+from theflow.settings import settings as flowsettings

 from kotaemon.base import BaseComponent, Document, Param
 from kotaemon.indices.extractors import BaseDocParser
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 from kotaemon.loaders import (
     AdobeReader,
+    AzureAIDocumentIntelligenceLoader,
     DirectoryReader,
     HtmlReader,
     MathpixPDFReader,
     MhtmlReader,
     OCRReader,
     PandasExcelReader,
+    PDFThumbnailReader,
     UnstructuredReader,
 )

 unstructured = UnstructuredReader()
+adobe_reader = AdobeReader()
+azure_reader = AzureAIDocumentIntelligenceLoader(
+    endpoint=str(config("AZURE_DI_ENDPOINT", default="")),
+    credential=str(config("AZURE_DI_CREDENTIAL", default="")),
+    cache_dir=getattr(flowsettings, "KH_MARKDOWN_OUTPUT_DIR", None),
+)
+adobe_reader.vlm_endpoint = azure_reader.vlm_endpoint = getattr(
+    flowsettings, "KH_VLM_ENDPOINT", ""
+)


 KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
     ".xlsx": PandasExcelReader(),
     ".docx": unstructured,
+    ".pptx": unstructured,
     ".xls": unstructured,
     ".doc": unstructured,
     ".html": HtmlReader(),

@@ -31,7 +46,7 @@ KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
     ".jpg": unstructured,
     ".tiff": unstructured,
     ".tif": unstructured,
-    ".pdf": PDFReader(),
+    ".pdf": PDFThumbnailReader(),
 }
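`KH_DEFAULT_FILE_EXTRACTORS` maps a lowercase file extension to a reader instance, so dispatch is a plain dict lookup. A hedged sketch of how an indexing pipeline might consume it (the stand-in values and fallback are illustrative only):

```python
from pathlib import Path

# Stand-in for the real mapping above; any object with load_data() works.
extractors = {".pdf": "PDFThumbnailReader()", ".docx": "unstructured"}

def pick_reader(path: str, default="UnstructuredReader()"):
    # Extensions are matched case-insensitively via suffix.lower().
    return extractors.get(Path(path).suffix.lower(), default)

print(pick_reader("report.PDF"))  # -> PDFThumbnailReader()
print(pick_reader("notes.md"))    # -> falls back to the default reader
```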
@@ -103,7 +103,9 @@ class CitationPipeline(BaseComponent):
         print("CitationPipeline: invoking LLM")
         llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
         print("CitationPipeline: finish invoking LLM")
-        if not llm_output.messages:
+        if not llm_output.messages or not llm_output.additional_kwargs.get(
+            "tool_calls"
+        ):
             return None
         function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
             "arguments"
@@ -1,5 +1,13 @@
 from .base import BaseReranking
 from .cohere import CohereReranking
 from .llm import LLMReranking
+from .llm_scoring import LLMScoring
+from .llm_trulens import LLMTrulensScoring

-__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]
+__all__ = [
+    "CohereReranking",
+    "LLMReranking",
+    "LLMScoring",
+    "BaseReranking",
+    "LLMTrulensScoring",
+]
@@ -1,6 +1,6 @@
 from __future__ import annotations

-import os
+from decouple import config

 from kotaemon.base import Document

@@ -9,8 +9,7 @@ from .base import BaseReranking

 class CohereReranking(BaseReranking):
     model_name: str = "rerank-multilingual-v2.0"
-    cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
-    top_k: int = 1
+    cohere_api_key: str = config("COHERE_API_KEY", "")

     def run(self, documents: list[Document], query: str) -> list[Document]:
         """Use Cohere Reranker model to re-order documents

@@ -22,6 +21,10 @@ class CohereReranking(BaseReranking):
                 "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
             )

+        if not self.cohere_api_key:
+            print("Cohere API key not found. Skipping reranking.")
+            return documents
+
         cohere_client = cohere.Client(self.cohere_api_key)
         compressed_docs: list[Document] = []

@@ -29,12 +32,13 @@ class CohereReranking(BaseReranking):
             return compressed_docs

         _docs = [d.content for d in documents]
-        results = cohere_client.rerank(
-            model=self.model_name, query=query, documents=_docs, top_n=self.top_k
+        response = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs
         )
-        for r in results:
+        print("Cohere score", [r.relevance_score for r in response.results])
+        for r in response.results:
             doc = documents[r.index]
-            doc.metadata["relevance_score"] = r.relevance_score
+            doc.metadata["cohere_reranking_score"] = r.relevance_score
             compressed_docs.append(doc)

         return compressed_docs
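The rerank call now keeps every candidate (no `top_n`) and reads scores from `response.results`, matching the response shape of recent `cohere` SDK versions. A hedged standalone sketch of the same call (placeholder key; model name as in the class above):

```python
import cohere  # assumes `pip install cohere`

client = cohere.Client("YOUR_COHERE_API_KEY")  # placeholder key
docs = ["Paris is the capital of France.", "Bananas are yellow."]

response = client.rerank(
    model="rerank-multilingual-v2.0",
    query="What is the capital of France?",
    documents=docs,
)
for r in response.results:
    # r.index points back into `docs`; r.relevance_score is in 0..1.
    print(r.index, round(r.relevance_score, 3), docs[r.index])
```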
libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from concurrent.futures import ThreadPoolExecutor
+
+import numpy as np
+from langchain.output_parsers.boolean import BooleanOutputParser
+
+from kotaemon.base import Document
+
+from .llm import LLMReranking
+
+
+class LLMScoring(LLMReranking):
+    def run(
+        self,
+        documents: list[Document],
+        query: str,
+    ) -> list[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs: list[Document] = []
+        output_parser = BooleanOutputParser()
+
+        if self.concurrent:
+            with ThreadPoolExecutor() as executor:
+                futures = []
+                for doc in documents:
+                    _prompt = self.prompt_template.populate(
+                        question=query, context=doc.get_content()
+                    )
+                    futures.append(executor.submit(lambda: self.llm(_prompt)))
+
+                results = [future.result() for future in futures]
+        else:
+            results = []
+            for doc in documents:
+                _prompt = self.prompt_template.populate(
+                    question=query, context=doc.get_content()
+                )
+                results.append(self.llm(_prompt))
+
+        for result, doc in zip(results, documents):
+            score = np.exp(np.average(result.logprobs))
+            include_doc = output_parser.parse(result.text)
+            if include_doc:
+                doc.metadata["llm_reranking_score"] = score
+            else:
+                doc.metadata["llm_reranking_score"] = 1 - score
+            filtered_docs.append(doc)
+
+        # prevent returning empty result
+        if len(filtered_docs) == 0:
+            filtered_docs = documents[: self.top_k]
+
+        return filtered_docs
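`LLMScoring` turns the token log-probabilities of the yes/no answer into a confidence: the mean logprob is exponentiated, giving the geometric-mean token probability, and a "no" answer flips it to `1 - score`. The arithmetic in isolation:

```python
import numpy as np

# Log-probabilities of the tokens in the LLM's "YES"/"NO" answer.
logprobs = [-0.1, -0.3]

# exp(mean(logprobs)) is the geometric mean of the token probabilities.
score = float(np.exp(np.average(logprobs)))
print(round(score, 3))      # ~0.819 -> stored when the parser reads "yes"
print(round(1 - score, 3))  # ~0.181 -> stored when the parser reads "no"
```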
libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py (new file, 182 lines)
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import re
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+import tiktoken
+
+from kotaemon.base import Document, HumanMessage, SystemMessage
+from kotaemon.indices.splitters import TokenSplitter
+from kotaemon.llms import BaseLLM, PromptTemplate
+
+from .llm import LLMReranking
+
+SYSTEM_PROMPT_TEMPLATE = PromptTemplate(
+    """You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION.
+Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant.
+
+A few additional scoring guidelines:
+
+- Long CONTEXTS should score equally well as short CONTEXTS.
+
+- RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION.
+
+- RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION.
+
+- CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE.
+
+- CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE.
+
+- CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE.
+
+- CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10.
+
+- Never elaborate."""  # noqa: E501
+)
+
+USER_PROMPT_TEMPLATE = PromptTemplate(
+    """QUESTION: {question}
+
+CONTEXT: {context}
+
+RELEVANCE: """
+)  # noqa
+
+PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)")
+"""Regex that matches integers."""
+
+MAX_CONTEXT_LEN = 7500
+
+
+def validate_rating(rating) -> int:
+    """Validate a rating is between 0 and 10."""
+    if not 0 <= rating <= 10:
+        raise ValueError("Rating must be between 0 and 10")
+
+    return rating
+
+
+def re_0_10_rating(s: str) -> int:
+    """Extract a 0-10 rating from a string.
+
+    If the string does not match an integer or matches an integer outside the
+    0-10 range, raises an error instead. If multiple numbers are found within
+    the expected 0-10 range, the smallest is returned.
+
+    Args:
+        s: String to extract rating from.
+
+    Returns:
+        int: Extracted rating.
+
+    Raises:
+        ParseError: If no integers between 0 and 10 are found in the string.
+    """
+
+    matches = PATTERN_INTEGER.findall(s)
+    if not matches:
+        raise AssertionError
+
+    vals = set()
+    for match in matches:
+        try:
+            vals.add(validate_rating(int(match)))
+        except ValueError:
+            pass
+
+    if not vals:
+        raise AssertionError
+
+    # Min to handle cases like "The rating is 8 out of 10."
+    return min(vals)
+
+
+class LLMTrulensScoring(LLMReranking):
+    llm: BaseLLM
+    system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE
+    user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE
+    concurrent: bool = True
+    normalize: float = 10
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=MAX_CONTEXT_LEN,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )
+
+    def run(
+        self,
+        documents: list[Document],
+        query: str,
+    ) -> list[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs = []
+
+        documents = sorted(documents, key=lambda doc: doc.get_content())
+        if self.concurrent:
+            with ThreadPoolExecutor() as executor:
+                futures = []
+                for doc in documents:
+                    chunked_doc_content = self.trim_func(
+                        [
+                            Document(content=doc.get_content())
+                            # skip metadata which cause troubles
+                        ]
+                    )[0].text
+
+                    messages = []
+                    messages.append(
+                        SystemMessage(self.system_prompt_template.populate())
+                    )
+                    messages.append(
+                        HumanMessage(
+                            self.user_prompt_template.populate(
+                                question=query, context=chunked_doc_content
+                            )
+                        )
+                    )
+
+                    def llm_call():
+                        return self.llm(messages).text
+
+                    futures.append(executor.submit(llm_call))
+
+                results = [future.result() for future in futures]
+        else:
+            results = []
+            for doc in documents:
+                messages = []
+                messages.append(SystemMessage(self.system_prompt_template.populate()))
+                messages.append(
+                    SystemMessage(
+                        self.user_prompt_template.populate(
+                            question=query, context=doc.get_content()
+                        )
+                    )
+                )
+                results.append(self.llm(messages).text)
+
+        # use Boolean parser to extract relevancy output from LLM
+        results = [
+            (r_idx, float(re_0_10_rating(result)) / self.normalize)
+            for r_idx, result in enumerate(results)
+        ]
+        results.sort(key=lambda x: x[1], reverse=True)
+
+        for r_idx, score in results:
+            doc = documents[r_idx]
+            doc.metadata["llm_trulens_score"] = score
+            filtered_docs.append(doc)
+
+        print(
+            "LLM rerank scores",
+            [doc.metadata["llm_trulens_score"] for doc in filtered_docs],
+        )
+
+        return filtered_docs
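`re_0_10_rating` is deliberately lenient about the grader's phrasing: it collects every integer in the reply, keeps the in-range ones, and returns the smallest, which handles replies like "8 out of 10". A quick check, assuming the module above is importable:

```python
from kotaemon.indices.rankings.llm_trulens import re_0_10_rating

print(re_0_10_rating("7"))                           # -> 7
print(re_0_10_rating("The rating is 8 out of 10."))  # -> 8 (min of {8, 10})
print(re_0_10_rating("8 out of 10") / 10)            # -> 0.8 after `normalize`
```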
@@ -23,7 +23,7 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
         )

     def _get_li_class(self):
-        from llama_index.text_splitter import TokenTextSplitter
+        from llama_index.core.text_splitter import TokenTextSplitter

         return TokenTextSplitter

@@ -44,6 +44,6 @@ class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
         )

     def _get_li_class(self):
-        from llama_index.node_parser import SentenceWindowNodeParser
+        from llama_index.core.node_parser import SentenceWindowNodeParser

         return SentenceWindowNodeParser
@@ -1,14 +1,18 @@
 from __future__ import annotations

+import threading
 import uuid
+from pathlib import Path
 from typing import Optional, Sequence, cast

+from theflow.settings import settings as flowsettings
+
 from kotaemon.base import BaseComponent, Document, RetrievedDocument
 from kotaemon.embeddings import BaseEmbeddings
 from kotaemon.storages import BaseDocumentStore, BaseVectorStore

 from .base import BaseIndexing, BaseRetrieval
-from .rankings import BaseReranking
+from .rankings import BaseReranking, LLMReranking

 VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"

@@ -23,9 +27,11 @@ class VectorIndexing(BaseIndexing):
     - List of texts
     """

+    cache_dir: Optional[str] = getattr(flowsettings, "KH_CHUNKS_OUTPUT_DIR", None)
     vector_store: BaseVectorStore
     doc_store: Optional[BaseDocumentStore] = None
     embedding: BaseEmbeddings
+    count_: int = 0

     def to_retrieval_pipeline(self, *args, **kwargs):
         """Convert the indexing pipeline to a retrieval pipeline"""

@@ -44,6 +50,52 @@ class VectorIndexing(BaseIndexing):
             qa_pipeline=CitationQAPipeline(**kwargs),
         )

+    def write_chunk_to_file(self, docs: list[Document]):
+        # save the chunks content into markdown format
+        if self.cache_dir:
+            file_name = Path(docs[0].metadata["file_name"])
+            for i in range(len(docs)):
+                markdown_content = ""
+                if "page_label" in docs[i].metadata:
+                    page_label = str(docs[i].metadata["page_label"])
+                    markdown_content += f"Page label: {page_label}"
+                if "file_name" in docs[i].metadata:
+                    filename = docs[i].metadata["file_name"]
+                    markdown_content += f"\nFile name: {filename}"
+                if "section" in docs[i].metadata:
+                    section = docs[i].metadata["section"]
+                    markdown_content += f"\nSection: {section}"
+                if "type" in docs[i].metadata:
+                    if docs[i].metadata["type"] == "image":
+                        image_origin = docs[i].metadata["image_origin"]
+                        image_origin = f'<p><img src="{image_origin}"></p>'
+                        markdown_content += f"\nImage origin: {image_origin}"
+                if docs[i].text:
+                    markdown_content += f"\ntext:\n{docs[i].text}"
+
+                with open(
+                    Path(self.cache_dir) / f"{file_name.stem}_{self.count_+i}.md",
+                    "w",
+                    encoding="utf-8",
+                ) as f:
+                    f.write(markdown_content)
+
+    def add_to_docstore(self, docs: list[Document]):
+        if self.doc_store:
+            print("Adding documents to doc store")
+            self.doc_store.add(docs)
+
+    def add_to_vectorstore(self, docs: list[Document]):
+        # in case we want to skip embedding
+        if self.vector_store:
+            print(f"Getting embeddings for {len(docs)} nodes")
+            embeddings = self.embedding(docs)
+            print("Adding embeddings to vector store")
+            self.vector_store.add(
+                embeddings=embeddings,
+                ids=[t.doc_id for t in docs],
+            )
+
     def run(self, text: str | list[str] | Document | list[Document]):
         input_: list[Document] = []
         if not isinstance(text, list):

@@ -59,16 +111,10 @@ class VectorIndexing(BaseIndexing):
                     f"Invalid input type {type(item)}, should be str or Document"
                 )

-        print(f"Getting embeddings for {len(input_)} nodes")
-        embeddings = self.embedding(input_)
-        print("Adding embeddings to vector store")
-        self.vector_store.add(
-            embeddings=embeddings,
-            ids=[t.doc_id for t in input_],
-        )
-        if self.doc_store:
-            print("Adding documents to doc store")
-            self.doc_store.add(input_)
+        self.add_to_vectorstore(input_)
+        self.add_to_docstore(input_)
+        self.write_chunk_to_file(input_)
+        self.count_ += len(input_)


 class VectorRetrieval(BaseRetrieval):
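`VectorIndexing.run` is now a thin orchestrator over three helpers, so each side effect (vector store, doc store, chunk dump) can also be invoked on its own. A hedged usage sketch; the store and embedding classes here are placeholders for whatever backends the app actually configures:

```python
# Sketch only: InMemoryVectorStore / InMemoryDocumentStore / OpenAIEmbeddings
# stand in for the concrete components wired up by the application.
indexing = VectorIndexing(
    vector_store=InMemoryVectorStore(),
    doc_store=InMemoryDocumentStore(),
    embedding=OpenAIEmbeddings(model="text-embedding-3-small"),
)
indexing.run(["first chunk of text", "second chunk of text"])
# count_ advances with every batch, keeping dumped chunk filenames unique.
print(indexing.count_)  # -> 2
```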
@@ -78,7 +124,16 @@ class VectorRetrieval(BaseRetrieval):
     doc_store: Optional[BaseDocumentStore] = None
     embedding: BaseEmbeddings
     rerankers: Sequence[BaseReranking] = []
-    top_k: int = 1
+    top_k: int = 5
+    first_round_top_k_mult: int = 10
+    retrieval_mode: str = "hybrid"  # vector, text, hybrid
+
+    def _filter_docs(
+        self, documents: list[RetrievedDocument], top_k: int | None = None
+    ):
+        if top_k:
+            documents = documents[:top_k]
+        return documents

     def run(
         self, text: str | Document, top_k: Optional[int] = None, **kwargs

@@ -95,24 +150,155 @@ class VectorRetrieval(BaseRetrieval):
         if top_k is None:
             top_k = self.top_k

+        do_extend = kwargs.pop("do_extend", False)
+        thumbnail_count = kwargs.pop("thumbnail_count", 3)
+
+        if do_extend:
+            top_k_first_round = top_k * self.first_round_top_k_mult
+        else:
+            top_k_first_round = top_k
+
         if self.doc_store is None:
             raise ValueError(
                 "doc_store is not provided. Please provide a doc_store to "
                 "retrieve the documents"
             )

-        emb: list[float] = self.embedding(text)[0].embedding
-        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
-        docs = self.doc_store.get(ids)
-        result = [
-            RetrievedDocument(**doc.to_dict(), score=score)
-            for doc, score in zip(docs, scores)
-        ]
+        result: list[RetrievedDocument] = []
+        # TODO: should declare scope directly in the run params
+        scope = kwargs.pop("scope", None)
+        emb: list[float]
+
+        if self.retrieval_mode == "vector":
+            emb = self.embedding(text)[0].embedding
+            _, scores, ids = self.vector_store.query(
+                embedding=emb, top_k=top_k_first_round, **kwargs
+            )
+            docs = self.doc_store.get(ids)
+            result = [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(docs, scores)
+            ]
+        elif self.retrieval_mode == "text":
+            query = text.text if isinstance(text, Document) else text
+            docs = self.doc_store.query(query, top_k=top_k_first_round, doc_ids=scope)
+            result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs]
+        elif self.retrieval_mode == "hybrid":
+            # similarity search section
+            emb = self.embedding(text)[0].embedding
+            vs_docs: list[RetrievedDocument] = []
+            vs_ids: list[str] = []
+            vs_scores: list[float] = []
+
+            def query_vectorstore():
+                nonlocal vs_docs
+                nonlocal vs_scores
+                nonlocal vs_ids
+
+                assert self.doc_store is not None
+                _, vs_scores, vs_ids = self.vector_store.query(
+                    embedding=emb, top_k=top_k_first_round, **kwargs
+                )
+                if vs_ids:
+                    vs_docs = self.doc_store.get(vs_ids)
+
+            # full-text search section
+            ds_docs: list[RetrievedDocument] = []
+
+            def query_docstore():
+                nonlocal ds_docs
+
+                assert self.doc_store is not None
+                query = text.text if isinstance(text, Document) else text
+                ds_docs = self.doc_store.query(
+                    query, top_k=top_k_first_round, doc_ids=scope
+                )
+
+            vs_query_thread = threading.Thread(target=query_vectorstore)
+            ds_query_thread = threading.Thread(target=query_docstore)
+
+            vs_query_thread.start()
+            ds_query_thread.start()
+
+            vs_query_thread.join()
+            ds_query_thread.join()
+
+            result = [
+                RetrievedDocument(**doc.to_dict(), score=-1.0)
+                for doc in ds_docs
+                if doc not in vs_ids
+            ]
+            result += [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(vs_docs, vs_scores)
+            ]
+            print(f"Got {len(vs_docs)} from vectorstore")
+            print(f"Got {len(ds_docs)} from docstore")

         # use additional reranker to re-order the document list
-        if self.rerankers:
+        if self.rerankers and text:
             for reranker in self.rerankers:
+                # if reranker is LLMReranking, limit the document with top_k items only
+                if isinstance(reranker, LLMReranking):
+                    result = self._filter_docs(result, top_k=top_k)
                 result = reranker(documents=result, query=text)

+        result = self._filter_docs(result, top_k=top_k)
+        print(f"Got raw {len(result)} retrieved documents")
+
+        # add page thumbnails to the result if exists
+        thumbnail_doc_ids: set[str] = set()
+        # we should copy the text from retrieved text chunk
+        # to the thumbnail to get relevant LLM score correctly
+        text_thumbnail_docs: dict[str, RetrievedDocument] = {}
+
+        non_thumbnail_docs = []
+        raw_thumbnail_docs = []
+        for doc in result:
+            if doc.metadata.get("type") == "thumbnail":
+                # change type to image to display on UI
+                doc.metadata["type"] = "image"
+                raw_thumbnail_docs.append(doc)
+                continue
+            if (
+                "thumbnail_doc_id" in doc.metadata
+                and len(thumbnail_doc_ids) < thumbnail_count
+            ):
+                thumbnail_id = doc.metadata["thumbnail_doc_id"]
+                thumbnail_doc_ids.add(thumbnail_id)
+                text_thumbnail_docs[thumbnail_id] = doc
+            else:
+                non_thumbnail_docs.append(doc)
+
+        linked_thumbnail_docs = self.doc_store.get(list(thumbnail_doc_ids))
+        print(
+            "thumbnail docs",
+            len(linked_thumbnail_docs),
+            "non-thumbnail docs",
+            len(non_thumbnail_docs),
+            "raw-thumbnail docs",
+            len(raw_thumbnail_docs),
+        )
+        additional_docs = []
+
+        for thumbnail_doc in linked_thumbnail_docs:
+            text_doc = text_thumbnail_docs[thumbnail_doc.doc_id]
+            doc_dict = thumbnail_doc.to_dict()
+            doc_dict["_id"] = text_doc.doc_id
+            doc_dict["content"] = text_doc.content
+            doc_dict["metadata"]["type"] = "image"
+            for key in text_doc.metadata:
+                if key not in doc_dict["metadata"]:
+                    doc_dict["metadata"][key] = text_doc.metadata[key]
+
+            additional_docs.append(RetrievedDocument(**doc_dict, score=text_doc.score))
+
+        result = additional_docs + non_thumbnail_docs
+
+        if not result:
+            # return output from raw retrieved thumbnails
+            result = self._filter_docs(raw_thumbnail_docs, top_k=thumbnail_count)
+
         return result
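Hybrid mode issues the dense (vector) and full-text queries in parallel with two `threading.Thread`s and merges the results, deduplicating docstore hits that the vector store already returned. The concurrency pattern in isolation (store queries replaced by stand-ins):

```python
import threading

vs_hits, ds_hits = [], []

def query_vectorstore():
    # stand-in for self.vector_store.query(...)
    vs_hits.extend(["doc-1", "doc-2"])

def query_docstore():
    # stand-in for self.doc_store.query(...)
    ds_hits.extend(["doc-2", "doc-3"])

threads = [threading.Thread(target=query_vectorstore),
           threading.Thread(target=query_docstore)]
for t in threads:
    t.start()
for t in threads:
    t.join()

# keep full-text hits that the vector search did not already return
merged = [d for d in ds_hits if d not in vs_hits] + vs_hits
print(merged)  # -> ['doc-3', 'doc-1', 'doc-2']
```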
@@ -7,6 +7,7 @@ from .chats import (
     ChatLLM,
     ChatOpenAI,
     EndpointChatLLM,
+    LCAnthropicChat,
     LCAzureChatOpenAI,
     LCChatOpenAI,
     LlamaCppChat,

@@ -27,6 +28,7 @@ __all__ = [
     "SystemMessage",
     "AzureChatOpenAI",
     "ChatOpenAI",
+    "LCAnthropicChat",
     "LCAzureChatOpenAI",
     "LCChatOpenAI",
     "LlamaCppChat",
@@ -1,6 +1,11 @@
 from .base import ChatLLM
 from .endpoint_based import EndpointChatLLM
-from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
+from .langchain_based import (
+    LCAnthropicChat,
+    LCAzureChatOpenAI,
+    LCChatMixin,
+    LCChatOpenAI,
+)
 from .llamacpp import LlamaCppChat
 from .openai import AzureChatOpenAI, ChatOpenAI

@@ -10,6 +15,7 @@ __all__ = [
     "ChatLLM",
     "EndpointChatLLM",
     "ChatOpenAI",
+    "LCAnthropicChat",
     "LCChatOpenAI",
     "LCAzureChatOpenAI",
     "LCChatMixin",
@@ -221,3 +221,27 @@ class LCAzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
         from langchain.chat_models import AzureChatOpenAI

         return AzureChatOpenAI
+
+
+class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        temperature: float = 0.7,
+        **params,
+    ):
+        super().__init__(
+            api_key=api_key,
+            model_name=model_name,
+            temperature=temperature,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_anthropic import ChatAnthropic
+        except ImportError:
+            raise ImportError("Please install langchain-anthropic")
+
+        return ChatAnthropic
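`LCAnthropicChat` follows the same `LCChatMixin` convention as the OpenAI wrappers: constructor kwargs are forwarded to the LangChain class resolved by `_get_lc_class`. A hedged instantiation sketch (placeholder key, example model id; requires `pip install langchain-anthropic`):

```python
from kotaemon.llms import LCAnthropicChat

llm = LCAnthropicChat(
    api_key="sk-ant-...",                  # placeholder
    model_name="claude-3-haiku-20240307",  # example Anthropic model id
    temperature=0.0,
)
response = llm("Summarize retrieval-augmented generation in one sentence.")
print(response.text)
```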
@@ -159,6 +159,15 @@ class BaseChatOpenAI(ChatLLM):
             additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
                 "tool_calls"
             ]

+        if resp["choices"][0].get("logprobs") is None:
+            logprobs = []
+        else:
+            all_logprobs = resp["choices"][0]["logprobs"].get("content")
+            logprobs = (
+                [logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
+            )
+
         output = LLMInterface(
             candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
             content=resp["choices"][0]["message"]["content"] or "",

@@ -170,6 +179,7 @@ class BaseChatOpenAI(ChatLLM):
                 AIMessage(content=(_["message"]["content"]) or "")
                 for _ in resp["choices"]
             ],
+            logprobs=logprobs,
         )

         return output

@@ -216,11 +226,24 @@ class BaseChatOpenAI(ChatLLM):
             client, messages=input_messages, stream=True, **kwargs
         )

-        for chunk in resp:
-            if not chunk.choices:
+        for c in resp:
+            chunk = c.dict()
+            if not chunk["choices"]:
                 continue
-            if chunk.choices[0].delta.content is not None:
-                yield LLMInterface(content=chunk.choices[0].delta.content)
+            if chunk["choices"][0]["delta"]["content"] is not None:
+                if chunk["choices"][0].get("logprobs") is None:
+                    logprobs = []
+                else:
+                    logprobs = [
+                        logprob["logprob"]
+                        for logprob in chunk["choices"][0]["logprobs"].get(
+                            "content", []
+                        )
+                    ]
+
+                yield LLMInterface(
+                    content=chunk["choices"][0]["delta"]["content"], logprobs=logprobs
+                )

     async def astream(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
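With logprobs requested (and the OpenAI SDK v1 response converted to a dict, as in the loops above), token log-probabilities sit under `choices[0].logprobs.content`. A hedged sketch of pulling them out of a response-shaped dict; the payload here is hand-written to mirror the chat completions API:

```python
resp = {
    "choices": [{
        "message": {"content": "Hello"},
        "logprobs": {"content": [
            {"token": "Hello", "logprob": -0.02},
        ]},
    }]
}

lp = resp["choices"][0].get("logprobs")
# `lp` is None when logprobs were not requested; guard both levels.
logprobs = [t["logprob"] for t in (lp or {}).get("content") or []]
print(logprobs)  # -> [-0.02]
```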
@@ -3,10 +3,12 @@ from .azureai_document_intelligence_loader import AzureAIDocumentIntelligenceLoa
 from .base import AutoReader, BaseReader
 from .composite_loader import DirectoryReader
 from .docx_loader import DocxReader
-from .excel_loader import PandasExcelReader
+from .excel_loader import ExcelReader, PandasExcelReader
 from .html_loader import HtmlReader, MhtmlReader
 from .mathpix_loader import MathpixPDFReader
 from .ocr_loader import ImageReader, OCRReader
+from .pdf_loader import PDFThumbnailReader
+from .txt_loader import TxtReader
 from .unstructured_loader import UnstructuredReader

 __all__ = [

@@ -14,6 +16,7 @@ __all__ = [
     "AzureAIDocumentIntelligenceLoader",
     "BaseReader",
     "PandasExcelReader",
+    "ExcelReader",
     "MathpixPDFReader",
     "ImageReader",
     "OCRReader",

@@ -23,4 +26,6 @@ __all__ = [
     "HtmlReader",
     "MhtmlReader",
     "AdobeReader",
+    "TxtReader",
+    "PDFThumbnailReader",
 ]
@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional

 from decouple import config
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -154,7 +154,7 @@ class AdobeReader(BaseReader):
         for page_number, table_content, table_caption in tables:
             documents.append(
                 Document(
-                    text=table_caption,
+                    text=table_content,
                     metadata={
                         "table_origin": table_content,
                         "type": "table",
@@ -1,10 +1,56 @@
+import base64
 import os
+from io import BytesIO
 from pathlib import Path
 from typing import Optional

+from PIL import Image
+
 from kotaemon.base import Document, Param

 from .base import BaseReader
+from .utils.adobe import generate_single_figure_caption
+
+
+def crop_image(file_path: Path, bbox: list[float], page_number: int = 0) -> Image.Image:
+    """Crop the image based on the bounding box
+
+    Args:
+        file_path (Path): path to the image file
+        bbox (list[float]): bounding box of the image (in percentage [x0, y0, x1, y1])
+        page_number (int, optional): page number of the image. Defaults to 0.
+
+    Returns:
+        Image.Image: cropped image
+    """
+    left, upper, right, lower = bbox
+
+    img: Image.Image
+    suffix = file_path.suffix.lower()
+    if suffix == ".pdf":
+        try:
+            import fitz
+        except ImportError:
+            raise ImportError("Please install PyMuPDF: 'pip install PyMuPDF'")
+
+        doc = fitz.open(file_path)
+        page = doc.load_page(page_number)
+        pm = page.get_pixmap(dpi=150)
+        img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
+    elif suffix in [".tif", ".tiff"]:
+        img = Image.open(file_path)
+        img.seek(page_number)
+    else:
+        img = Image.open(file_path)
+
+    return img.crop(
+        (
+            int(left * img.width),
+            int(upper * img.height),
+            int(right * img.width),
+            int(lower * img.height),
+        )
+    )


 class AzureAIDocumentIntelligenceLoader(BaseReader):

@@ -14,7 +60,7 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
     heif, docx, xlsx, pptx and html.
     """

-    _dependencies = ["azure-ai-documentintelligence"]
+    _dependencies = ["azure-ai-documentintelligence", "PyMuPDF", "Pillow"]

     endpoint: str = Param(
         os.environ.get("AZUREAI_DOCUMENT_INTELLIGENT_ENDPOINT", None),

@@ -34,6 +80,29 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
             "#model-analysis-features)"
         ),
     )
+    output_content_format: str = Param(
+        "markdown",
+        help="Output content format. Can be 'markdown' or 'text'.Default is markdown",
+    )
+    vlm_endpoint: str = Param(
+        help=(
+            "Default VLM endpoint for figure captioning. If not provided, will not "
+            "caption the figures"
+        )
+    )
+    figure_friendly_filetypes: list[str] = Param(
+        [".pdf", ".jpeg", ".jpg", ".png", ".bmp", ".tiff", ".heif", ".tif"],
+        help=(
+            "File types that we can reliably open and extract figures. "
+            "For files like .docx or .html, the visual layout may be different "
+            "when viewed from different tools, hence we cannot use Azure DI "
+            "location to extract figures."
+        ),
+    )
+    cache_dir: str = Param(
+        None,
+        help="Directory to cache the downloaded files. Default is None",
+    )

     @Param.auto(depends_on=["endpoint", "credential"])
     def client_(self):

@@ -55,14 +124,114 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> list[Document]:
+        """Extract the input file, allowing multi-modal extraction"""
         metadata = extra_info or {}
+        file_name = Path(file_path)
         with open(file_path, "rb") as fi:
             poller = self.client_.begin_analyze_document(
                 self.model,
                 analyze_request=fi,
                 content_type="application/octet-stream",
-                output_content_format="markdown",
+                output_content_format=self.output_content_format,
             )
             result = poller.result()

-        return [Document(content=result.content, metadata=metadata)]
+        # the total text content of the document in `output_content_format` format
+        text_content = result.content
+        removed_spans: list[dict] = []
+
+        # extract the figures
+        figures = []
+        for figure_desc in result.get("figures", []):
+            if not self.vlm_endpoint:
+                continue
+            if file_path.suffix.lower() not in self.figure_friendly_filetypes:
+                continue
+
+            # read & crop the image
+            page_number = figure_desc["boundingRegions"][0]["pageNumber"]
+            page_width = result.pages[page_number - 1]["width"]
+            page_height = result.pages[page_number - 1]["height"]
+            polygon = figure_desc["boundingRegions"][0]["polygon"]
+            xs = [polygon[i] for i in range(0, len(polygon), 2)]
+            ys = [polygon[i] for i in range(1, len(polygon), 2)]
+            bbox = [
+                min(xs) / page_width,
+                min(ys) / page_height,
+                max(xs) / page_width,
+                max(ys) / page_height,
+            ]
+            img = crop_image(file_path, bbox, page_number - 1)
+
+            # convert the image into base64
+            img_bytes = BytesIO()
+            img.save(img_bytes, format="PNG")
+            img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
+            img_base64 = f"data:image/png;base64,{img_base64}"
+
+            # caption the image
+            caption = generate_single_figure_caption(
+                figure=img_base64, vlm_endpoint=self.vlm_endpoint
+            )
+
+            # store the image into document
+            figure_metadata = {
+                "image_origin": img_base64,
+                "type": "image",
+                "page_label": page_number,
+            }
+            figure_metadata.update(metadata)
+
+            figures.append(
+                Document(
+                    text=caption,
+                    metadata=figure_metadata,
+                )
+            )
+            removed_spans += figure_desc["spans"]
+
+        # extract the tables
+        tables = []
+        for table_desc in result.get("tables", []):
+            if not table_desc["spans"]:
+                continue
+
+            # convert the tables into markdown format
+            boundingRegions = table_desc["boundingRegions"]
+            if boundingRegions:
+                page_number = boundingRegions[0]["pageNumber"]
+            else:
+                page_number = 1
+
+            # store the tables into document
+            offset = table_desc["spans"][0]["offset"]
+            length = table_desc["spans"][0]["length"]
+            table_metadata = {
+                "type": "table",
+                "page_label": page_number,
+                "table_origin": text_content[offset : offset + length],
+            }
+            table_metadata.update(metadata)
+
+            tables.append(
+                Document(
+                    text=text_content[offset : offset + length],
+                    metadata=table_metadata,
+                )
+            )
+            removed_spans += table_desc["spans"]
+        # save the text content into markdown format
+        if self.cache_dir is not None:
+            with open(
+                Path(self.cache_dir) / f"{file_name.stem}.md", "w", encoding="utf-8"
+            ) as f:
+                f.write(text_content)
+
+        removed_spans = sorted(removed_spans, key=lambda x: x["offset"], reverse=True)
+        for span in removed_spans:
+            text_content = (
+                text_content[: span["offset"]]
+                + text_content[span["offset"] + span["length"] :]
+            )
+
+        return [Document(content=text_content, metadata=metadata)] + figures + tables
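Azure DI reports figure locations as a polygon of alternating `(x, y)` pairs in page units; the loader converts that to a page-relative `[x0, y0, x1, y1]` box before cropping. The normalization in isolation (the numbers are made up):

```python
# Polygon as alternating x, y coordinates (inches) for one bounding region.
polygon = [1.0, 2.0, 4.0, 2.0, 4.0, 5.0, 1.0, 5.0]
page_width, page_height = 8.5, 11.0

xs = polygon[0::2]
ys = polygon[1::2]
# Axis-aligned box, each coordinate scaled into the 0..1 page range.
bbox = [min(xs) / page_width, min(ys) / page_height,
        max(xs) / page_width, max(ys) / page_height]
print([round(v, 3) for v in bbox])  # -> [0.118, 0.182, 0.471, 0.455]
```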
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, List, Type, Union
 from kotaemon.base import BaseComponent, Document

 if TYPE_CHECKING:
-    from llama_index.readers.base import BaseReader as LIBaseReader
+    from llama_index.core.readers.base import BaseReader as LIBaseReader


 class BaseReader(BaseComponent):

@@ -20,7 +20,7 @@ class AutoReader(BaseReader):
         """Init reader using string identifier or class name from llama-hub"""

         if isinstance(reader_type, str):
-            from llama_index import download_loader
+            from llama_index.core import download_loader

             self._reader = download_loader(reader_type)()
         else:
@@ -1,6 +1,6 @@
 from typing import Callable, List, Optional, Type

-from llama_index.readers.base import BaseReader as LIBaseReader
+from llama_index.core.readers.base import BaseReader as LIBaseReader

 from .base import BaseReader, LIReaderMixin

@@ -48,6 +48,6 @@ class DirectoryReader(LIReaderMixin, BaseReader):
     file_metadata: Optional[Callable[[str], dict]] = None

     def _get_wrapped_class(self) -> Type["LIBaseReader"]:
-        from llama_index import SimpleDirectoryReader
+        from llama_index.core import SimpleDirectoryReader

         return SimpleDirectoryReader
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import List, Optional

 import pandas as pd
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -27,6 +27,21 @@ class DocxReader(BaseReader):
                 "Please install it using `pip install python-docx`"
             )

+    def _load_single_table(self, table) -> List[List[str]]:
+        """Extract content from tables. Return a list of columns: list[str]
+        Some merged cells will share duplicated content.
+        """
+        n_row = len(table.rows)
+        n_col = len(table.columns)
+
+        arrays = [["" for _ in range(n_row)] for _ in range(n_col)]
+
+        for i, row in enumerate(table.rows):
+            for j, cell in enumerate(row.cells):
+                arrays[j][i] = cell.text
+
+        return arrays
+
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> List[Document]:

@@ -50,13 +65,9 @@ class DocxReader(BaseReader):
         tables = []
         for t in doc.tables:
-            arrays = [
-                [
-                    unicodedata.normalize("NFKC", t.cell(i, j).text)
-                    for i in range(len(t.rows))
-                ]
-                for j in range(len(t.columns))
-            ]
+            # return list of columns: list of string
+            arrays = self._load_single_table(t)
             tables.append(pd.DataFrame({a[0]: a[1:] for a in arrays}))

         extra_info = extra_info or {}
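Note (illustration, not part of the diff): the new `_load_single_table` keeps the column-major layout the deleted inline comprehension produced, where each inner list is one column and its first element becomes the header, so the `pd.DataFrame({a[0]: a[1:] for a in arrays})` line is untouched. A minimal sketch with made-up cell values:

import pandas as pd

# Column-major output: inner list = one column, element 0 = header.
arrays = [
    ["name", "alice", "bob"],   # hypothetical column
    ["score", "8", "9"],        # hypothetical column
]
df = pd.DataFrame({a[0]: a[1:] for a in arrays})
print(df.to_string(index=False))
#  name score
# alice     8
#   bob     9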
@@ -6,7 +6,7 @@ Pandas parser for .xlsx files.
 from pathlib import Path
 from typing import Any, List, Optional, Union

-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -82,6 +82,9 @@ class PandasExcelReader(BaseReader):
             sheet = []
             if include_sheetname:
                 sheet.append([key])
+            dfs[key] = dfs[key].dropna(axis=0, how="all")
+            dfs[key] = dfs[key].dropna(axis=0, how="all")
+            dfs[key].fillna("", inplace=True)
             sheet.extend(dfs[key].values.astype(str).tolist())
             df_sheets.append(sheet)

@@ -99,3 +102,91 @@ class PandasExcelReader(BaseReader):
         ]

         return output
+
+
+class ExcelReader(BaseReader):
+    r"""Spreadsheet exporter respecting multiple worksheets
+
+    Parses CSVs using the separator detection from Pandas `read_csv` function.
+    If special parameters are required, use the `pandas_config` dict.
+
+    Args:
+        pandas_config (dict): Options for the `pandas.read_excel` function call.
+            Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html
+            for more information. Set to empty dict by default,
+            this means defaults will be used.
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        pandas_config: Optional[dict] = None,
+        row_joiner: str = "\n",
+        col_joiner: str = " ",
+        **kwargs: Any,
+    ) -> None:
+        """Init params."""
+        super().__init__(*args, **kwargs)
+        self._pandas_config = pandas_config or {}
+        self._row_joiner = row_joiner if row_joiner else "\n"
+        self._col_joiner = col_joiner if col_joiner else " "
+
+    def load_data(
+        self,
+        file: Path,
+        include_sheetname: bool = True,
+        sheet_name: Optional[Union[str, int, list]] = None,
+        extra_info: Optional[dict] = None,
+        **kwargs,
+    ) -> List[Document]:
+        """Parse file and extract values from a specific column.
+
+        Args:
+            file (Path): The path to the Excel file to read.
+            include_sheetname (bool): Whether to include the sheet name in the output.
+            sheet_name (Union[str, int, None]): The specific sheet to read from,
+                default is None which reads all sheets.
+
+        Returns:
+            List[Document]: A list of Document objects containing the
+                values from the specified column in the Excel file.
+        """
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ImportError(
+                "install pandas using `pip3 install pandas` to use this loader"
+            )
+
+        if sheet_name is not None:
+            sheet_name = (
+                [sheet_name] if not isinstance(sheet_name, list) else sheet_name
+            )
+
+        # clean up input
+        file = Path(file)
+        extra_info = extra_info or {}
+
+        dfs = pd.read_excel(file, sheet_name=sheet_name, **self._pandas_config)
+        sheet_names = dfs.keys()
+        output = []
+
+        for idx, key in enumerate(sheet_names):
+            dfs[key] = dfs[key].dropna(axis=0, how="all")
+            dfs[key] = dfs[key].dropna(axis=0, how="all")
+            dfs[key] = dfs[key].astype("object")
+            dfs[key].fillna("", inplace=True)
+
+            rows = dfs[key].values.astype(str).tolist()
+            content = self._row_joiner.join(
+                self._col_joiner.join(row).strip() for row in rows
+            ).strip()
+            if include_sheetname:
+                content = f"(Sheet {key} of file {file.name})\n{content}"
+            metadata = {"page_label": idx + 1, "sheet_name": key, **extra_info}
+            output.append(Document(text=content, metadata=metadata))
+
+        return output
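Note (illustration, not part of the diff): a usage sketch for the new ExcelReader, one Document per worksheet. The import path and file name are assumptions, not confirmed by this diff.

from pathlib import Path

from kotaemon.loaders import ExcelReader  # assumed export location

reader = ExcelReader(row_joiner="\n", col_joiner=" ")
docs = reader.load_data(Path("report.xlsx"))  # hypothetical file
for doc in docs:
    # one entry per sheet; page_label is the 1-based sheet index
    print(doc.metadata["sheet_name"], doc.metadata["page_label"])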
@@ -2,7 +2,8 @@ import email
 from pathlib import Path
 from typing import Optional

-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader
+from theflow.settings import settings as flowsettings

 from kotaemon.base import Document

@@ -78,6 +79,9 @@ class MhtmlReader(BaseReader):

     def __init__(
         self,
+        cache_dir: Optional[str] = getattr(
+            flowsettings, "KH_MARKDOWN_OUTPUT_DIR", None
+        ),
         open_encoding: Optional[str] = None,
         bs_kwargs: Optional[dict] = None,
         get_text_separator: str = "",

@@ -86,6 +90,7 @@ class MhtmlReader(BaseReader):
         to pass to the BeautifulSoup object.

         Args:
+            cache_dir: Path for markdown format.
             file_path: Path to file to load.
             open_encoding: The encoding to use when opening the file.
             bs_kwargs: Any kwargs to pass to the BeautifulSoup object.

@@ -100,6 +105,7 @@ class MhtmlReader(BaseReader):
                 "`pip install beautifulsoup4`"
             )

+        self.cache_dir = cache_dir
         self.open_encoding = open_encoding
         if bs_kwargs is None:
             bs_kwargs = {"features": "lxml"}

@@ -116,6 +122,7 @@ class MhtmlReader(BaseReader):
         extra_info = extra_info or {}
         metadata: dict = extra_info
         page = []
+        file_name = Path(file_path)
         with open(file_path, "r", encoding=self.open_encoding) as f:
             message = email.message_from_string(f.read())
             parts = message.get_payload()

@@ -144,5 +151,11 @@ class MhtmlReader(BaseReader):
             text = "\n\n".join(lines)
             if text:
                 page.append(text)
+        # save the page into markdown format
+        print(self.cache_dir)
+        if self.cache_dir is not None:
+            print(Path(self.cache_dir) / f"{file_name.stem}.md")
+            with open(Path(self.cache_dir) / f"{file_name.stem}.md", "w") as f:
+                f.write(page[0])

         return [Document(text="\n\n".join(page), metadata=metadata)]
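Note (illustration, not part of the diff): with this change, MhtmlReader picks its cache directory from theflow's KH_MARKDOWN_OUTPUT_DIR by default and, when set, writes the parsed page out as markdown next to returning the Document. A sketch with a hypothetical path and import location:

from kotaemon.loaders import MhtmlReader  # assumed export location

reader = MhtmlReader(cache_dir="/tmp/markdown_out")  # hypothetical directory
docs = reader.load_data("saved_page.mhtml")
# side effect of the new code: /tmp/markdown_out/saved_page.md is written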
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional

 import requests
 from langchain.utils import get_from_dict_or_env
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document
@@ -5,8 +5,8 @@ from typing import List, Optional
 from uuid import uuid4

 import requests
-from llama_index.readers.base import BaseReader
-from tenacity import after_log, retry, stop_after_attempt, wait_fixed, wait_random
+from llama_index.core.readers.base import BaseReader
+from tenacity import after_log, retry, stop_after_attempt, wait_exponential

 from kotaemon.base import Document

@@ -19,13 +19,16 @@ DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"


 @retry(
-    stop=stop_after_attempt(3),
-    wait=wait_fixed(5) + wait_random(0, 2),
-    after=after_log(logger, logging.DEBUG),
+    stop=stop_after_attempt(6),
+    wait=wait_exponential(multiplier=20, exp_base=2, min=1, max=1000),
+    after=after_log(logger, logging.WARNING),
 )
-def tenacious_api_post(url, **kwargs):
-    resp = requests.post(url=url, **kwargs)
-    resp.raise_for_status()
+def tenacious_api_post(url, file_path, table_only, **kwargs):
+    with file_path.open("rb") as content:
+        files = {"input": content}
+        data = {"job_id": uuid4(), "table_only": table_only}
+        resp = requests.post(url=url, files=files, data=data, **kwargs)
+        resp.raise_for_status()
     return resp


@@ -71,18 +74,16 @@ class OCRReader(BaseReader):
         """
         file_path = Path(file_path).resolve()

-        with file_path.open("rb") as content:
-            files = {"input": content}
-            data = {"job_id": uuid4(), "table_only": not self.use_ocr}
-
         # call the API from FullOCR endpoint
         if "response_content" in kwargs:
             # overriding response content if specified
             ocr_results = kwargs["response_content"]
         else:
             # call original API
-            resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
+            resp = tenacious_api_post(
+                url=self.ocr_endpoint, file_path=file_path, table_only=not self.use_ocr
+            )
             ocr_results = resp.json()["result"]

         debug_path = kwargs.pop("debug_path", None)
         artifact_path = kwargs.pop("artifact_path", None)

@@ -168,18 +169,16 @@ class ImageReader(BaseReader):
         """
         file_path = Path(file_path).resolve()

-        with file_path.open("rb") as content:
-            files = {"input": content}
-            data = {"job_id": uuid4(), "table_only": False}
-
         # call the API from FullOCR endpoint
         if "response_content" in kwargs:
             # overriding response content if specified
             ocr_results = kwargs["response_content"]
         else:
             # call original API
-            resp = tenacious_api_post(
-                url=self.ocr_endpoint, files=files, data=data)
+            resp = tenacious_api_post(
+                url=self.ocr_endpoint, file_path=file_path, table_only=False
+            )
             ocr_results = resp.json()["result"]

         extra_info = extra_info or {}
         result = []
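Note (illustration, not part of the diff): the retry policy moves from three attempts with roughly 5-7 s fixed waits to six attempts with exponential backoff. Per tenacity's documented formula (an assumption worth checking against the pinned version), the wait before retry n is multiplier * exp_base**(n - 1), clamped to [min, max]:

# Expected waits between the 6 attempts under the new decorator.
waits = [min(max(20 * 2 ** (n - 1), 1), 1000) for n in range(1, 6)]
print(waits)  # [20, 40, 80, 160, 320] seconds

Moving the file-open into tenacious_api_post also matters for correctness: each retry now reopens the file, so a retried POST never sends an already-consumed file handle.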
libs/kotaemon/kotaemon/loaders/pdf_loader.py (new file, 114 lines)
@@ -0,0 +1,114 @@
+import base64
+from io import BytesIO
+from pathlib import Path
+from typing import Dict, List, Optional
+
+from fsspec import AbstractFileSystem
+from llama_index.readers.file import PDFReader
+from PIL import Image
+
+from kotaemon.base import Document
+
+
+def get_page_thumbnails(
+    file_path: Path, pages: list[int], dpi: int = 80
+) -> List[Image.Image]:
+    """Get image thumbnails of the pages in the PDF file.
+
+    Args:
+        file_path (Path): path to the PDF file
+        pages (list[int]): list of page numbers to extract
+
+    Returns:
+        list[Image.Image]: list of page thumbnails
+    """
+
+    img: Image.Image
+    suffix = file_path.suffix.lower()
+    assert suffix == ".pdf", "This function only supports PDF files."
+    try:
+        import fitz
+    except ImportError:
+        raise ImportError("Please install PyMuPDF: 'pip install PyMuPDF'")
+
+    doc = fitz.open(file_path)
+
+    output_imgs = []
+    for page_number in pages:
+        page = doc.load_page(page_number)
+        pm = page.get_pixmap(dpi=dpi)
+        img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
+        output_imgs.append(convert_image_to_base64(img))
+
+    return output_imgs
+
+
+def convert_image_to_base64(img: Image.Image) -> str:
+    # convert the image into base64
+    img_bytes = BytesIO()
+    img.save(img_bytes, format="PNG")
+    img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
+    img_base64 = f"data:image/png;base64,{img_base64}"
+    return img_base64
+
+
+class PDFThumbnailReader(PDFReader):
+    """PDF parser with thumbnail for each page."""
+
+    def __init__(self) -> None:
+        """
+        Initialize PDFReader.
+        """
+        super().__init__(return_full_document=False)
+
+    def load_data(
+        self,
+        file: Path,
+        extra_info: Optional[Dict] = None,
+        fs: Optional[AbstractFileSystem] = None,
+    ) -> List[Document]:
+        """Parse file."""
+        documents = super().load_data(file, extra_info, fs)
+
+        page_numbers_str = []
+        filtered_docs = []
+        is_int_page_number: dict[str, bool] = {}
+
+        for doc in documents:
+            if "page_label" in doc.metadata:
+                page_num_str = doc.metadata["page_label"]
+                page_numbers_str.append(page_num_str)
+                try:
+                    _ = int(page_num_str)
+                    is_int_page_number[page_num_str] = True
+                    filtered_docs.append(doc)
+                except ValueError:
+                    is_int_page_number[page_num_str] = False
+                    continue
+
+        documents = filtered_docs
+        page_numbers = list(range(len(page_numbers_str)))
+
+        print("Page numbers:", len(page_numbers))
+        page_thumbnails = get_page_thumbnails(file, page_numbers)
+
+        documents.extend(
+            [
+                Document(
+                    text="Page thumbnail",
+                    metadata={
+                        "image_origin": page_thumbnail,
+                        "type": "thumbnail",
+                        "page_label": page_number,
+                        **(extra_info if extra_info is not None else {}),
+                    },
+                )
+                for (page_thumbnail, page_number) in zip(
+                    page_thumbnails, page_numbers_str
+                )
+                if is_int_page_number[page_number]
+            ]
+        )
+
+        return documents
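Note (illustration, not part of the diff): a usage sketch for the new reader. Page text comes from llama-index's PDFReader; each thumbnail is a separate Document whose metadata["image_origin"] holds a base64 PNG data URI. The file name is hypothetical.

from pathlib import Path

from kotaemon.loaders.pdf_loader import PDFThumbnailReader  # path as added above

reader = PDFThumbnailReader()
docs = reader.load_data(Path("paper.pdf"))  # hypothetical file
thumbs = [d for d in docs if d.metadata.get("type") == "thumbnail"]
print(len(thumbs), thumbs[0].metadata["image_origin"][:22])  # data:image/png;base64,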
libs/kotaemon/kotaemon/loaders/txt_loader.py (new file, 22 lines)
@@ -0,0 +1,22 @@
+from pathlib import Path
+from typing import Optional
+
+from kotaemon.base import Document
+
+from .base import BaseReader
+
+
+class TxtReader(BaseReader):
+    def run(
+        self, file_path: str | Path, extra_info: Optional[dict] = None, **kwargs
+    ) -> list[Document]:
+        return self.load_data(Path(file_path), extra_info=extra_info, **kwargs)
+
+    def load_data(
+        self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
+    ) -> list[Document]:
+        with open(file_path, "r") as f:
+            text = f.read()
+
+        metadata = extra_info or {}
+        return [Document(text=text, metadata=metadata)]
@@ -12,7 +12,7 @@ pip install xlrd
 from pathlib import Path
 from typing import Any, Dict, List, Optional

-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document
@@ -1,12 +1,19 @@
 import json
+import logging
 from typing import Any, List

 import requests
 from decouple import config

+logger = logging.getLogger(__name__)
+

 def generate_gpt4v(
-    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+    endpoint: str,
+    images: str | List[str],
+    prompt: str,
+    max_tokens: int = 512,
+    max_images: int = 10,
 ) -> str:
     # OpenAI API Key
     api_key = config("AZURE_OPENAI_API_KEY", default="")

@@ -27,24 +34,36 @@ def generate_gpt4v(
                         "type": "image_url",
                         "image_url": {"url": image},
                     }
-                    for image in images
+                    for image in images[:max_images]
                 ],
             }
         ],
         "max_tokens": max_tokens,
+        "temperature": 0,
     }

+    if len(images) > max_images:
+        print(f"Truncated to {max_images} images (original {len(images)} images)")
+
+    response = requests.post(endpoint, headers=headers, json=payload)
+
     try:
-        response = requests.post(endpoint, headers=headers, json=payload)
-        output = response.json()
-        output = output["choices"][0]["message"]["content"]
-    except Exception:
-        output = ""
+        response.raise_for_status()
+    except Exception as e:
+        logger.exception(f"Error generating gpt4v: {response.text}; error {e}")
+        return ""
+
+    output = response.json()
+    output = output["choices"][0]["message"]["content"]
     return output


 def stream_gpt4v(
-    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+    endpoint: str,
+    images: str | List[str],
+    prompt: str,
+    max_tokens: int = 512,
+    max_images: int = 10,
 ) -> Any:
     # OpenAI API Key
     api_key = config("AZURE_OPENAI_API_KEY", default="")

@@ -65,17 +84,22 @@ def stream_gpt4v(
                         "type": "image_url",
                         "image_url": {"url": image},
                     }
-                    for image in images
+                    for image in images[:max_images]
                 ],
             }
         ],
         "max_tokens": max_tokens,
         "stream": True,
+        "logprobs": True,
+        "temperature": 0,
     }
+    if len(images) > max_images:
+        print(f"Truncated to {max_images} images (original {len(images)} images)")
     try:
         response = requests.post(endpoint, headers=headers, json=payload, stream=True)
         assert response.status_code == 200, str(response.content)
         output = ""
+        logprobs = []
         for line in response.iter_lines():
             if line:
                 if line.startswith(b"\xef\xbb\xbf"):

@@ -89,8 +113,23 @@ def stream_gpt4v(
                 except Exception:
                     break
                 if len(line["choices"]):
+                    if line["choices"][0].get("logprobs") is None:
+                        _logprobs = []
+                    else:
+                        _logprobs = [
+                            logprob["logprob"]
+                            for logprob in line["choices"][0]["logprobs"].get(
+                                "content", []
+                            )
+                        ]
+
                     output += line["choices"][0]["delta"].get("content", "")
-                    yield line["choices"][0]["delta"].get("content", "")
-    except Exception:
+                    logprobs += _logprobs
+                    yield line["choices"][0]["delta"].get("content", ""), _logprobs
+
+    except Exception as e:
+        logger.error(f"Error streaming gpt4v {e}")
+        logprobs = []
         output = ""
-    return output
+    return output, logprobs
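Note (illustration, not part of the diff): stream_gpt4v now yields (chunk, token_logprobs) pairs and returns (full_text, all_logprobs), which is what the confidence-score features elsewhere in this PR consume. One illustrative way, not kotaemon's actual scoring code, to collapse token logprobs into a single confidence:

import math

def confidence_from_logprobs(logprobs: list[float]) -> float:
    """Geometric-mean token probability, in [0, 1]."""
    if not logprobs:
        return 0.0
    return math.exp(sum(logprobs) / len(logprobs))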
@@ -2,12 +2,14 @@ from .docstores import (
     BaseDocumentStore,
     ElasticsearchDocumentStore,
     InMemoryDocumentStore,
+    LanceDBDocumentStore,
     SimpleFileDocumentStore,
 )
 from .vectorstores import (
     BaseVectorStore,
     ChromaVectorStore,
     InMemoryVectorStore,
+    LanceDBVectorStore,
     SimpleFileVectorStore,
 )

@@ -17,9 +19,11 @@ __all__ = [
     "InMemoryDocumentStore",
     "ElasticsearchDocumentStore",
     "SimpleFileDocumentStore",
+    "LanceDBDocumentStore",
     # Vector stores
     "BaseVectorStore",
     "ChromaVectorStore",
     "InMemoryVectorStore",
     "SimpleFileVectorStore",
+    "LanceDBVectorStore",
 ]
@@ -1,6 +1,7 @@
 from .base import BaseDocumentStore
 from .elasticsearch import ElasticsearchDocumentStore
 from .in_memory import InMemoryDocumentStore
+from .lancedb import LanceDBDocumentStore
 from .simple_file import SimpleFileDocumentStore

 __all__ = [
@@ -8,4 +9,5 @@ __all__ = [
     "InMemoryDocumentStore",
     "ElasticsearchDocumentStore",
     "SimpleFileDocumentStore",
+    "LanceDBDocumentStore",
 ]
@@ -41,6 +41,13 @@ class BaseDocumentStore(ABC):
         """Count number of documents"""
         ...

+    @abstractmethod
+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
+        """Search document store using search query"""
+        ...
+
     @abstractmethod
     def delete(self, ids: Union[List[str], str]):
         """Delete document by id"""
@@ -92,7 +92,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 "_id": doc_id,
             }
             requests.append(request)
-        self.es_bulk(self.client, requests)
+
+        success, failed = self.es_bulk(self.client, requests)
+        print("Added/Updated documents to index", success)
+        print("Failed documents to index", failed)

         if refresh_indices:
             self.client.indices.refresh(index=self.index_name)

@@ -131,16 +134,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         Returns:
             List[Document]: List of result documents
         """
-        query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
-        if doc_ids:
-            query_dict["query"]["match"]["_id"] = {"values": doc_ids}
+        query_dict: dict = {"match": {"content": query}}
+        if doc_ids is not None:
+            query_dict = {"bool": {"must": [query_dict, {"terms": {"_id": doc_ids}}]}}
+        query_dict = {"query": query_dict, "size": top_k}
         return self.query_raw(query_dict)

     def get(self, ids: Union[List[str], str]) -> List[Document]:
         """Get document by id"""
         if not isinstance(ids, list):
             ids = [ids]
-        query_dict = {"query": {"terms": {"_id": ids}}}
+        query_dict = {"query": {"terms": {"_id": ids}}, "size": 10000}
         return self.query_raw(query_dict)

     def count(self) -> int:
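Note (illustration, not part of the diff): the rewritten query() combines the full-text match and the id restriction in a bool/must clause instead of stuffing an _id condition into the match query, which Elasticsearch does not accept. With hypothetical inputs (query="hybrid retrieval", top_k=10, doc_ids=["doc-1", "doc-2"]) the request body becomes:

query_dict = {
    "query": {
        "bool": {
            "must": [
                {"match": {"content": "hybrid retrieval"}},
                {"terms": {"_id": ["doc-1", "doc-2"]}},
            ]
        }
    },
    "size": 10,
}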
@@ -81,6 +81,12 @@ class InMemoryDocumentStore(BaseDocumentStore):
         # Also, for portability, use SQLAlchemy for document store.
         self._store = {key: Document.from_dict(value) for key, value in store.items()}

+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
+        """Perform full-text search on document store"""
+        return []
+
     def __persist_flow__(self):
         return {}
libs/kotaemon/kotaemon/storages/docstores/lancedb.py (new file, 153 lines)
@@ -0,0 +1,153 @@
+import json
+from typing import List, Optional, Union
+
+from kotaemon.base import Document
+
+from .base import BaseDocumentStore
+
+MAX_DOCS_TO_GET = 10**4
+
+
+class LanceDBDocumentStore(BaseDocumentStore):
+    """LanceDB document store which supports full-text search queries"""
+
+    def __init__(self, path: str = "lancedb", collection_name: str = "docstore"):
+        try:
+            import lancedb
+        except ImportError:
+            raise ImportError(
+                "Please install lancedb: 'pip install lancedb tantivy-py'"
+            )
+
+        self.db_uri = path
+        self.collection_name = collection_name
+        self.db_connection = lancedb.connect(self.db_uri)  # type: ignore
+
+    def add(
+        self,
+        docs: Union[Document, List[Document]],
+        ids: Optional[Union[List[str], str]] = None,
+        refresh_indices: bool = True,
+        **kwargs,
+    ):
+        """Load documents into lancedb storage."""
+        doc_ids = ids if ids else [doc.doc_id for doc in docs]
+        data: list[dict[str, str]] | None = [
+            {
+                "id": doc_id,
+                "text": doc.text,
+                "attributes": json.dumps(doc.metadata),
+            }
+            for doc_id, doc in zip(doc_ids, docs)
+        ]
+
+        if self.collection_name not in self.db_connection.table_names():
+            if data:
+                document_collection = self.db_connection.create_table(
+                    self.collection_name, data=data, mode="overwrite"
+                )
+        else:
+            # add data to existing table
+            document_collection = self.db_connection.open_table(self.collection_name)
+            if data:
+                document_collection.add(data)
+
+        if refresh_indices:
+            document_collection.create_fts_index(
+                "text",
+                tokenizer_name="en_stem",
+                replace=True,
+            )
+
+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
+        if doc_ids:
+            id_filter = ", ".join([f"'{_id}'" for _id in doc_ids])
+            query_filter = f"id in ({id_filter})"
+        else:
+            query_filter = None
+        try:
+            document_collection = self.db_connection.open_table(self.collection_name)
+            if query_filter:
+                docs = (
+                    document_collection.search(query, query_type="fts")
+                    .where(query_filter, prefilter=True)
+                    .limit(top_k)
+                    .to_list()
+                )
+            else:
+                docs = (
+                    document_collection.search(query, query_type="fts")
+                    .limit(top_k)
+                    .to_list()
+                )
+        except (ValueError, FileNotFoundError):
+            docs = []
+        return [
+            Document(
+                id_=doc["id"],
+                text=doc["text"] if doc["text"] else "<empty>",
+                metadata=json.loads(doc["attributes"]),
+            )
+            for doc in docs
+        ]
+
+    def get(self, ids: Union[List[str], str]) -> List[Document]:
+        """Get document by id"""
+        if not isinstance(ids, list):
+            ids = [ids]
+
+        id_filter = ", ".join([f"'{_id}'" for _id in ids])
+        try:
+            document_collection = self.db_connection.open_table(self.collection_name)
+            query_filter = f"id in ({id_filter})"
+            docs = (
+                document_collection.search()
+                .where(query_filter)
+                .limit(MAX_DOCS_TO_GET)
+                .to_list()
+            )
+        except (ValueError, FileNotFoundError):
+            docs = []
+        return [
+            Document(
+                id_=doc["id"],
+                text=doc["text"] if doc["text"] else "<empty>",
+                metadata=json.loads(doc["attributes"]),
+            )
+            for doc in docs
+        ]
+
+    def delete(self, ids: Union[List[str], str], refresh_indices: bool = True):
+        """Delete document by id"""
+        if not isinstance(ids, list):
+            ids = [ids]
+
+        document_collection = self.db_connection.open_table(self.collection_name)
+        id_filter = ", ".join([f"'{_id}'" for _id in ids])
+        query_filter = f"id in ({id_filter})"
+        document_collection.delete(query_filter)
+
+        if refresh_indices:
+            document_collection.create_fts_index(
+                "text",
+                tokenizer_name="en_stem",
+                replace=True,
+            )
+
+    def drop(self):
+        """Drop the document store"""
+        self.db_connection.drop_table(self.collection_name)
+
+    def count(self) -> int:
+        raise NotImplementedError
+
+    def get_all(self) -> List[Document]:
+        raise NotImplementedError
+
+    def __persist_flow__(self):
+        return {
+            "db_uri": self.db_uri,
+            "collection_name": self.collection_name,
+        }
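Note (illustration, not part of the diff): a usage sketch for the new store, using the exports added above. Documents are persisted as (id, text, json metadata) rows and queried via LanceDB's tantivy-backed full-text index; the sample document is made up.

from kotaemon.base import Document
from kotaemon.storages import LanceDBDocumentStore

store = LanceDBDocumentStore(path="lancedb", collection_name="docstore")
store.add([Document(text="hybrid vector retrieval", metadata={"file_id": "f1"})])
hits = store.query("retrieval", top_k=5)  # BM25-style full-text search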
@@ -1,6 +1,7 @@
 from .base import BaseVectorStore
 from .chroma import ChromaVectorStore
 from .in_memory import InMemoryVectorStore
+from .lancedb import LanceDBVectorStore
 from .simple_file import SimpleFileVectorStore

 __all__ = [
@@ -8,4 +9,5 @@ __all__ = [
     "ChromaVectorStore",
     "InMemoryVectorStore",
     "SimpleFileVectorStore",
+    "LanceDBVectorStore",
 ]
@@ -3,10 +3,10 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Optional

-from llama_index.schema import NodeRelationship, RelatedNodeInfo
-from llama_index.vector_stores.types import BasePydanticVectorStore
-from llama_index.vector_stores.types import VectorStore as LIVectorStore
-from llama_index.vector_stores.types import VectorStoreQuery
+from llama_index.core.schema import NodeRelationship, RelatedNodeInfo
+from llama_index.core.vector_stores.types import BasePydanticVectorStore
+from llama_index.core.vector_stores.types import VectorStore as LIVectorStore
+from llama_index.core.vector_stores.types import VectorStoreQuery

 from kotaemon.base import DocumentWithEmbedding
@@ -2,8 +2,8 @@
 from typing import Any, Optional, Type

 import fsspec
-from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
-from llama_index.vector_stores.simple import SimpleVectorStoreData
+from llama_index.core.vector_stores import SimpleVectorStore as LISimpleVectorStore
+from llama_index.core.vector_stores.simple import SimpleVectorStoreData

 from .base import LlamaIndexVectorStore
libs/kotaemon/kotaemon/storages/vectorstores/lancedb.py (new file, 87 lines)
@@ -0,0 +1,87 @@
+from typing import Any, List, Type, cast
+
+from llama_index.core.vector_stores.types import MetadataFilters
+from llama_index.vector_stores.lancedb import LanceDBVectorStore as LILanceDBVectorStore
+from llama_index.vector_stores.lancedb import base as base_lancedb
+
+from .base import LlamaIndexVectorStore
+
+# custom monkey patch for LanceDB
+original_to_lance_filter = base_lancedb._to_lance_filter
+
+
+def custom_to_lance_filter(
+    standard_filters: MetadataFilters, metadata_keys: list
+) -> Any:
+    for filter in standard_filters.filters:
+        if isinstance(filter.value, list):
+            # quote string values if filter are list of strings
+            if filter.value and isinstance(filter.value[0], str):
+                filter.value = [f"'{v}'" for v in filter.value]
+
+    return original_to_lance_filter(standard_filters, metadata_keys)
+
+
+# skip table existence check
+LILanceDBVectorStore._table_exists = lambda _: False
+base_lancedb._to_lance_filter = custom_to_lance_filter
+
+
+class LanceDBVectorStore(LlamaIndexVectorStore):
+    _li_class: Type[LILanceDBVectorStore] = LILanceDBVectorStore
+
+    def __init__(
+        self,
+        path: str = "./lancedb",
+        collection_name: str = "default",
+        **kwargs: Any,
+    ):
+        self._path = path
+        self._collection_name = collection_name
+
+        try:
+            import lancedb
+        except ImportError:
+            raise ImportError(
+                "Please install lancedb: 'pip install lancedb tantivy-py'"
+            )
+
+        db_connection = lancedb.connect(path)  # type: ignore
+        try:
+            table = db_connection.open_table(collection_name)
+        except FileNotFoundError:
+            table = None
+
+        self._kwargs = kwargs
+
+        # pass through for nice IDE support
+        super().__init__(
+            uri=path,
+            table_name=collection_name,
+            table=table,
+            **kwargs,
+        )
+        self._client = cast(LILanceDBVectorStore, self._client)
+        self._client._metadata_keys = ["file_id"]
+
+    def delete(self, ids: List[str], **kwargs):
+        """Delete vector embeddings from vector stores
+
+        Args:
+            ids: List of ids of the embeddings to be deleted
+            kwargs: meant for vectorstore-specific parameters
+        """
+        self._client.delete_nodes(ids)
+
+    def drop(self):
+        """Delete entire collection from vector stores"""
+        self._client.client.drop_table(self.collection_name)
+
+    def count(self) -> int:
+        raise NotImplementedError
+
+    def __persist_flow__(self):
+        return {
+            "path": self._path,
+            "collection_name": self._collection_name,
+        }
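Note (illustration, not part of the diff): what the monkey-patched filter conversion fixes, with made-up ids. A list-valued string filter such as "file_id in [...]" must reach LanceDB's SQL-like where clause with each value quoted, otherwise the generated predicate is invalid:

values = ["file-1", "file-2"]            # hypothetical file ids
quoted = [f"'{v}'" for v in values]      # what custom_to_lance_filter produces
print(f"file_id IN ({', '.join(quoted)})")  # file_id IN ('file-1', 'file-2')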
@@ -3,8 +3,8 @@ from pathlib import Path
 from typing import Any, Optional, Type

 import fsspec
-from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
-from llama_index.vector_stores.simple import SimpleVectorStoreData
+from llama_index.core.vector_stores import SimpleVectorStore as LISimpleVectorStore
+from llama_index.core.vector_stores.simple import SimpleVectorStoreData

 from kotaemon.base import DocumentWithEmbedding
@@ -26,9 +26,11 @@ dependencies = [
     "langchain-openai>=0.1.4,<0.2.0",
     "openai>=1.23.6,<2",
     "theflow>=0.8.6,<0.9.0",
-    "llama-index==0.9.48",
+    "llama-index>=0.10.40,<0.11.0",
+    "llama-index-vector-stores-chroma>=0.1.9",
+    "llama-index-vector-stores-lancedb",
     "llama-hub>=0.0.79,<0.1.0",
-    "gradio>=4.26.0,<5",
+    "gradio>=4.31.0,<4.40",
     "openpyxl>=3.1.2,<3.2",
     "cookiecutter>=2.6.0,<2.7",
     "click>=8.1.7,<9",

@@ -36,13 +38,9 @@ dependencies = [
     "trogon>=0.5.0,<0.6",
     "tenacity>=8.2.3,<8.3",
     "python-dotenv>=1.0.1,<1.1",
-    "chromadb>=0.4.21,<0.5",
-    "unstructured==0.13.4",
     "pypdf>=4.2.0,<4.3",
+    "PyMuPDF>=1.23",
     "html2text==2024.2.26",
-    "fastembed==0.2.6",
-    "llama-cpp-python>=0.2.72,<0.3",
-    "azure-ai-documentintelligence",
     "cohere>=5.3.2,<5.4",
 ]
 readme = "README.md"

@@ -63,11 +61,12 @@ adv = [
     "duckduckgo-search>=6.1.0,<6.2",
     "googlesearch-python>=1.2.4,<1.3",
     "python-docx>=1.1.0,<1.2",
-    "unstructured[pdf]==0.13.4",
-    "sentence_transformers==2.7.0",
     "elasticsearch>=8.13.0,<8.14",
-    "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
     "beautifulsoup4>=4.12.3,<4.13",
+    "plotly",
+    "tabulate",
+    "fast_langdetect",
+    "azure-ai-documentintelligence",
 ]
 dev = [
     "ipython",
@@ -2,7 +2,7 @@ from pathlib import Path
 from unittest.mock import patch

 from langchain.schema import Document as LangchainDocument
-from llama_index.node_parser import SimpleNodeParser
+from llama_index.core.node_parser import SimpleNodeParser

 from kotaemon.base import Document
 from kotaemon.loaders import (
@@ -1,4 +1,4 @@
-from llama_index.schema import NodeRelationship
+from llama_index.core.schema import NodeRelationship

 from kotaemon.base import Document
 from kotaemon.indices.splitters import TokenSplitter
libs/ktem/.gitignore (vendored)
@@ -1,2 +1,3 @@
 14-1_抜粋-1.pdf
 _example_.db
+ktem/assets/prebuilt/
@@ -4,6 +4,7 @@ from typing import Optional
 import gradio as gr
 import pluggy
 from ktem import extension_protocol
+from ktem.assets import PDFJS_PREBUILT_DIR
 from ktem.components import reasonings
 from ktem.exceptions import HookAlreadyDeclared, HookNotDeclared
 from ktem.index import IndexManager

@@ -36,6 +37,7 @@ class BaseApp:
     def __init__(self):
         self.dev_mode = getattr(settings, "KH_MODE", "") == "dev"
         self.app_name = getattr(settings, "KH_APP_NAME", "Kotaemon")
+        self.app_version = getattr(settings, "KH_APP_VERSION", "")
         self.f_user_management = getattr(settings, "KH_FEATURE_USER_MANAGEMENT", False)
         self._theme = gr.Theme.from_hub("lone17/kotaemon")

@@ -44,6 +46,13 @@ class BaseApp:
             self._css = fi.read()
         with (dir_assets / "js" / "main.js").open() as fi:
             self._js = fi.read()
+            self._js = self._js.replace("KH_APP_VERSION", self.app_version)
+        with (dir_assets / "js" / "pdf_viewer.js").open() as fi:
+            self._pdf_view_js = fi.read()
+            self._pdf_view_js = self._pdf_view_js.replace(
+                "PDFJS_PREBUILT_DIR", str(PDFJS_PREBUILT_DIR)
+            )

         self._favicon = str(dir_assets / "img" / "favicon.svg")

         self.default_settings = SettingGroup(

@@ -156,11 +165,17 @@ class BaseApp:
         """Called when the app is created"""

     def make(self):
+        external_js = """
+            <script type="module" src="https://cdn.skypack.dev/pdfjs-viewer-element"></script>
+        """
+
         with gr.Blocks(
             theme=self._theme,
             css=self._css,
             title=self.app_name,
             analytics_enabled=False,
+            js=self._js,
+            head=external_js,
         ) as demo:
             self.app = demo
             self.settings_state.render()

@@ -173,6 +188,8 @@ class BaseApp:
             self.register_events()
         self.on_app_created()

+        demo.load(None, None, None, js=self._pdf_view_js)
+
         return demo

     def declare_public_events(self):

@@ -200,7 +217,6 @@ class BaseApp:

     def on_app_created(self):
         """Execute on app created callbacks"""
-        self.app.load(lambda: None, None, None, js=f"() => {{{self._js}}}")
         self._on_app_created()
         for value in self.__dict__.values():
             if isinstance(value, BasePage):
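Note (illustration, not part of the diff): a minimal sketch of the gradio 4.x mechanism this change relies on. Blocks(js=...) runs a script once the app mounts, Blocks(head=...) injects raw HTML into the page head (how the pdfjs-viewer-element module gets loaded), and demo.load(..., js=...) runs a snippet on every page load. The snippets here are placeholders.

import gradio as gr

head = '<script type="module" src="https://cdn.skypack.dev/pdfjs-viewer-element"></script>'

with gr.Blocks(js="() => { console.log('app mounted'); }", head=head) as demo:
    gr.Markdown("hello")

demo.load(None, None, None, js="() => { console.log('page load'); }")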
libs/ktem/ktem/assets/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
+from pathlib import Path
+
+from decouple import config
+
+PDFJS_VERSION_DIST: str = config("PDFJS_VERSION_DIST", "pdfjs-4.0.379-dist")
+PDFJS_PREBUILT_DIR: Path = Path(__file__).parent / "prebuilt" / PDFJS_VERSION_DIST
@@ -147,6 +147,16 @@ mark {
   max-height: 42px;
 }

+/* Hide sort buttons at gr.DataFrame */
+.sort-button {
+  display: none !important;
+}
+
+/* Show sort button only in File list */
+#file_list_view .sort-button {
+  display: block !important;
+}
+
 .scrollable {
   overflow-y: auto;
 }

@@ -158,3 +168,58 @@ mark {
 .unset-overflow {
   overflow: unset !important;
 }
+
+/*body {*/
+/*  margin: 0;*/
+/*  font-family: Arial, sans-serif;*/
+/*}*/
+
+pdfjs-viewer-element {
+  height: 100vh;
+  height: 100dvh;
+}
+
+/* Modal styles */
+
+.modal {
+  display: none;
+  position: relative;
+  z-index: 1;
+  left: 0;
+  top: 0;
+  width: 100%;
+  height: 100%;
+  overflow: auto;
+  background-color: rgb(0, 0, 0);
+  background-color: rgba(0, 0, 0, 0.4);
+}
+
+.modal-header {
+  padding: 0px 10px
+}
+
+.modal-content {
+  background-color: #fefefe;
+  height: 110%;
+  display: flex;
+  flex-direction: column;
+}
+
+.close {
+  color: #aaa;
+  align-self: flex-end;
+  font-size: 28px;
+  font-weight: bold;
+}
+
+.close:hover,
+.close:focus {
+  color: black;
+  text-decoration: none;
+  cursor: pointer;
+}
+
+.modal-body {
+  flex: 1;
+  overflow: auto;
+}
libs/ktem/ktem/assets/icons/delete.svg (new file)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="h-5 w-5 shrink-0"><path fill="#f93a37" fill-rule="evenodd" d="M10.556 4a1 1 0 0 0-.97.751l-.292 1.14h5.421l-.293-1.14A1 1 0 0 0 13.453 4zm6.224 1.892-.421-1.639A3 3 0 0 0 13.453 2h-2.897A3 3 0 0 0 7.65 4.253l-.421 1.639H4a1 1 0 1 0 0 2h.1l1.215 11.425A3 3 0 0 0 8.3 22h7.4a3 3 0 0 0 2.984-2.683l1.214-11.425H20a1 1 0 1 0 0-2zm1.108 2H6.112l1.192 11.214A1 1 0 0 0 8.3 20h7.4a1 1 0 0 0 .995-.894zM10 10a1 1 0 0 1 1 1v5a1 1 0 1 1-2 0v-5a1 1 0 0 1 1-1m4 0a1 1 0 0 1 1 1v5a1 1 0 1 1-2 0v-5a1 1 0 0 1 1-1" clip-rule="evenodd"/></svg>

libs/ktem/ktem/assets/icons/new.svg (new file)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="#10b981" class="icon-xl-heavy"><path d="M15.673 3.913a3.121 3.121 0 1 1 4.414 4.414l-5.937 5.937a5 5 0 0 1-2.828 1.415l-2.18.31a1 1 0 0 1-1.132-1.13l.311-2.18A5 5 0 0 1 9.736 9.85zm3 1.414a1.12 1.12 0 0 0-1.586 0l-5.937 5.937a3 3 0 0 0-.849 1.697l-.123.86.86-.122a3 3 0 0 0 1.698-.849l5.937-5.937a1.12 1.12 0 0 0 0-1.586M11 4a1 1 0 0 1-1 1c-.998 0-1.702.008-2.253.06-.54.052-.862.141-1.109.267a3 3 0 0 0-1.311 1.311c-.134.263-.226.611-.276 1.216C5.001 8.471 5 9.264 5 10.4v3.2c0 1.137 0 1.929.051 2.546.05.605.142.953.276 1.216a3 3 0 0 0 1.311 1.311c.263.134.611.226 1.216.276.617.05 1.41.051 2.546.051h3.2c1.137 0 1.929 0 2.546-.051.605-.05.953-.142 1.216-.276a3 3 0 0 0 1.311-1.311c.126-.247.215-.569.266-1.108.053-.552.06-1.256.06-2.255a1 1 0 1 1 2 .002c0 .978-.006 1.78-.069 2.442-.064.673-.192 1.27-.475 1.827a5 5 0 0 1-2.185 2.185c-.592.302-1.232.428-1.961.487C15.6 21 14.727 21 13.643 21h-3.286c-1.084 0-1.958 0-2.666-.058-.728-.06-1.369-.185-1.96-.487a5 5 0 0 1-2.186-2.185c-.302-.592-.428-1.233-.487-1.961C3 15.6 3 14.727 3 13.643v-3.286c0-1.084 0-1.958.058-2.666.06-.729.185-1.369.487-1.961A5 5 0 0 1 5.73 3.545c.556-.284 1.154-.411 1.827-.475C8.22 3.007 9.021 3 10 3a1 1 0 0 1 1 1"/></svg>

libs/ktem/ktem/assets/icons/rename.svg (new file)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="h-5 w-5 shrink-0"><path fill="#cecece" fill-rule="evenodd" d="M13.293 4.293a4.536 4.536 0 1 1 6.414 6.414l-1 1-7.094 7.094A5 5 0 0 1 8.9 20.197l-4.736.79a1 1 0 0 1-1.15-1.151l.789-4.736a5 5 0 0 1 1.396-2.713zM13 7.414l-6.386 6.387a3 3 0 0 0-.838 1.628l-.56 3.355 3.355-.56a3 3 0 0 0 1.628-.837L16.586 11zm5 2.172L14.414 6l.293-.293a2.536 2.536 0 0 1 3.586 3.586z" clip-rule="evenodd"/></svg>

libs/ktem/ktem/assets/icons/sidebar.svg (new file)
@@ -0,0 +1 @@
+<svg xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" class="icon-xl-heavy"><path fill="#cecece" fill-rule="evenodd" d="M8.857 3h6.286c1.084 0 1.958 0 2.666.058.729.06 1.369.185 1.961.487a5 5 0 0 1 2.185 2.185c.302.592.428 1.233.487 1.961.058.708.058 1.582.058 2.666v3.286c0 1.084 0 1.958-.058 2.666-.06.729-.185 1.369-.487 1.961a5 5 0 0 1-2.185 2.185c-.592.302-1.232.428-1.961.487C17.1 21 16.227 21 15.143 21H8.857c-1.084 0-1.958 0-2.666-.058-.728-.06-1.369-.185-1.96-.487a5 5 0 0 1-2.186-2.185c-.302-.592-.428-1.232-.487-1.961C1.5 15.6 1.5 14.727 1.5 13.643v-3.286c0-1.084 0-1.958.058-2.666.06-.728.185-1.369.487-1.96A5 5 0 0 1 4.23 3.544c.592-.302 1.233-.428 1.961-.487C6.9 3 7.773 3 8.857 3M6.354 5.051c-.605.05-.953.142-1.216.276a3 3 0 0 0-1.311 1.311c-.134.263-.226.611-.276 1.216-.05.617-.051 1.41-.051 2.546v3.2c0 1.137 0 1.929.051 2.546.05.605.142.953.276 1.216a3 3 0 0 0 1.311 1.311c.263.134.611.226 1.216.276.617.05 1.41.051 2.546.051h.6V5h-.6c-1.137 0-1.929 0-2.546.051M11.5 5v14h3.6c1.137 0 1.929 0 2.546-.051.605-.05.953-.142 1.216-.276a3 3 0 0 0 1.311-1.311c.134-.263.226-.611.276-1.216.05-.617.051-1.41.051-2.546v-3.2c0-1.137 0-1.929-.051-2.546-.05-.605-.142-.953-.276-1.216a3 3 0 0 0-1.311-1.311c-.263-.134-.611-.226-1.216-.276C17.029 5.001 16.236 5 15.1 5zM5 8.5a1 1 0 0 1 1-1h1a1 1 0 1 1 0 2H6a1 1 0 0 1-1-1M5 12a1 1 0 0 1 1-1h1a1 1 0 1 1 0 2H6a1 1 0 0 1-1-1" clip-rule="evenodd"/></svg>
@@ -1,30 +1,37 @@
-let main_parent = document.getElementById("chat-tab").parentNode;
+function run() {
+  let main_parent = document.getElementById("chat-tab").parentNode;

   main_parent.childNodes[0].classList.add("header-bar");
   main_parent.style = "padding: 0; margin: 0";
   main_parent.parentNode.style = "gap: 0";
   main_parent.parentNode.parentNode.style = "padding: 0";

+  const version_node = document.createElement("p");
+  version_node.innerHTML = "version: KH_APP_VERSION";
+  version_node.style = "position: fixed; top: 10px; right: 10px;";
+  main_parent.appendChild(version_node);
+
   // clpse
   globalThis.clpseFn = (id) => {
     var obj = document.getElementById('clpse-btn-' + id);
     obj.classList.toggle("clpse-active");
     var content = obj.nextElementSibling;
     if (content.style.display === "none") {
       content.style.display = "block";
     } else {
       content.style.display = "none";
+    }
+  }
+
+  // store info in local storage
+  globalThis.setStorage = (key, value) => {
+    localStorage.setItem(key, value)
+  }
+  globalThis.getStorage = (key, value) => {
+    item = localStorage.getItem(key);
+    return item ? item : value;
+  }
+  globalThis.removeFromStorage = (key) => {
+    localStorage.removeItem(key)
   }
 }
-
-// store info in local storage
-globalThis.setStorage = (key, value) => {
-  localStorage.setItem(key, JSON.stringify(value))
-}
-globalThis.getStorage = (key, value) => {
-  return JSON.parse(localStorage.getItem(key))
-}
-globalThis.removeFromStorage = (key) => {
-  localStorage.removeItem(key)
-}
libs/ktem/ktem/assets/js/pdf_viewer.js (new file, 99 lines)
@@ -0,0 +1,99 @@
+function onBlockLoad () {
+  var infor_panel_scroll_pos = 0;
+  globalThis.createModal = () => {
+    // Create modal for the 1st time if it does not exist
+    var modal = document.getElementById("pdf-modal");
+    var old_position = null;
+    var old_width = null;
+    var old_left = null;
+    var expanded = false;
+
+    modal.id = "pdf-modal";
+    modal.className = "modal";
+    modal.innerHTML = `
+      <div class="modal-content">
+        <div class="modal-header">
+          <span class="close" id="modal-close">&times;</span>
+          <span class="close" id="modal-expand">⛶</span>
+        </div>
+        <div class="modal-body">
+          <pdfjs-viewer-element id="pdf-viewer" viewer-path="/file=PDFJS_PREBUILT_DIR" locale="en" phrase="true">
+          </pdfjs-viewer-element>
+        </div>
+      </div>
+    `;
+
+    modal.querySelector("#modal-close").onclick = function() {
+      modal.style.display = "none";
+      var info_panel = document.getElementById("html-info-panel");
+      if (info_panel) {
+        info_panel.style.display = "block";
+      }
+      var scrollableDiv = document.getElementById("chat-info-panel");
+      scrollableDiv.scrollTop = infor_panel_scroll_pos;
+    };
+
+    modal.querySelector("#modal-expand").onclick = function () {
+      expanded = !expanded;
+      if (expanded) {
+        old_position = modal.style.position;
+        old_left = modal.style.left;
+        old_width = modal.style.width;
+
+        modal.style.position = "fixed";
+        modal.style.width = "70%";
+        modal.style.left = "15%";
+      } else {
+        modal.style.position = old_position;
+        modal.style.width = old_width;
+        modal.style.left = old_left;
+      }
+    };
+  }
+
+  // Function to open modal and display PDF
+  globalThis.openModal = (event) => {
+    event.preventDefault();
+    var target = event.currentTarget;
+    var src = target.getAttribute("data-src");
+    var page = target.getAttribute("data-page");
+    var search = target.getAttribute("data-search");
+    var phrase = target.getAttribute("data-phrase");
+
+    var pdfViewer = document.getElementById("pdf-viewer");
+
+    current_src = pdfViewer.getAttribute("src");
+    if (current_src != src) {
+      pdfViewer.setAttribute("src", src);
+    }
+    pdfViewer.setAttribute("phrase", phrase);
+    pdfViewer.setAttribute("search", search);
+    pdfViewer.setAttribute("page", page);
+
+    var scrollableDiv = document.getElementById("chat-info-panel");
+    infor_panel_scroll_pos = scrollableDiv.scrollTop;
+
+    var modal = document.getElementById("pdf-modal")
+    modal.style.display = "block";
+    var info_panel = document.getElementById("html-info-panel");
+    if (info_panel) {
+      info_panel.style.display = "none";
+    }
+    scrollableDiv.scrollTop = 0;
+  }
+
+  globalThis.assignPdfOnclickEvent = () => {
+    // Get all links and attach click event
+    var links = document.getElementsByClassName("pdf-link");
+    for (var i = 0; i < links.length; i++) {
+      links[i].onclick = openModal;
+    }
+  }
+
+  var created_modal = document.getElementById("pdf-viewer");
+  if (!created_modal) {
+    createModal();
+    console.log("Created modal")
+  }
+
+}
@@ -8,3 +8,6 @@ An open-source tool for you to chat with your documents.
 [User Guide](https://cinnamon.github.io/kotaemon/) |
 [Developer Guide](https://cinnamon.github.io/kotaemon/development/) |
 [Feedback](https://github.com/Cinnamon/kotaemon/issues)
+
+[Dark Mode](?__theme=dark)
+[Night Mode](?__theme=light)
@@ -136,6 +136,6 @@ Now navigate back to the `Chat` tab. The chat tab is divided into 3 regions:
    files will be considered during chat.
 2. Chat Panel
    - This is where you can chat with the chatbot.
-3. Information panel
+3. Information Panel
    - Supporting information such as the retrieved evidence and reference will be
      displayed here.
@@ -1,9 +1,11 @@
 import datetime
 import uuid
 from typing import Optional
+from zoneinfo import ZoneInfo

 from sqlalchemy import JSON, Column
 from sqlmodel import Field, SQLModel
+from theflow.settings import settings as flowsettings


 class BaseConversation(SQLModel):
@@ -24,10 +26,14 @@ class BaseConversation(SQLModel):
         default_factory=lambda: uuid.uuid4().hex, primary_key=True, index=True
     )
     name: str = Field(
-        default_factory=lambda: datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S")
+        default_factory=lambda: datetime.datetime.now(
+            ZoneInfo(getattr(flowsettings, "TIME_ZONE", "UTC"))
+        ).strftime("%Y-%m-%d %H:%M:%S")
     )
     user: int = Field(default=0)  # For now we only have one user

+    is_public: bool = Field(default=False)
+
     # contains messages + current files
     data_source: dict = Field(default={}, sa_column=Column(JSON))
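For reference, a minimal standalone sketch of what the new default-name factory does (not part of the diff; the TIME_ZONE value below is a hypothetical flowsettings entry, falling back to "UTC" when absent):

import datetime
from zoneinfo import ZoneInfo

class _FakeSettings:
    TIME_ZONE = "Asia/Tokyo"  # hypothetical value; getattr falls back to "UTC"

flowsettings = _FakeSettings()

# same expression as the new default_factory above
name = datetime.datetime.now(
    ZoneInfo(getattr(flowsettings, "TIME_ZONE", "UTC"))
).strftime("%Y-%m-%d %H:%M:%S")
print(name)  # e.g. "2024-08-01 09:30:00", local to the configured zone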
@@ -36,7 +36,7 @@ class EmbeddingManager:

     def load(self):
         """Load the model pool from database"""
-        self._models, self._info, self._defaut = {}, {}, ""
+        self._models, self._info, self._default = {}, {}, ""
         with Session(engine) as sess:
             stmt = select(EmbeddingTable)
             items = sess.execute(stmt)
@@ -115,7 +115,7 @@ class EmbeddingManagement(BasePage):
         """Called when the app is created"""
         self._app.app.load(
             self.list_embeddings,
-            inputs=None,
+            inputs=[],
             outputs=[self.emb_list],
         )
         self._app.app.load(
@@ -144,7 +144,7 @@ class EmbeddingManagement(BasePage):
             self.create_emb,
             inputs=[self.name, self.emb_choices, self.spec, self.default],
             outputs=None,
-        ).success(self.list_embeddings, inputs=None, outputs=[self.emb_list]).success(
+        ).success(self.list_embeddings, inputs=[], outputs=[self.emb_list]).success(
             lambda: ("", None, "", False, self.spec_desc_default),
             outputs=[
                 self.name,
@@ -179,7 +179,7 @@ class EmbeddingManagement(BasePage):
         )
         self.btn_delete.click(
             self.on_btn_delete_click,
-            inputs=None,
+            inputs=[],
             outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
             show_progress="hidden",
         )
@@ -190,7 +190,7 @@ class EmbeddingManagement(BasePage):
             show_progress="hidden",
         ).then(
             self.list_embeddings,
-            inputs=None,
+            inputs=[],
             outputs=[self.emb_list],
         )
         self.btn_delete_no.click(
@@ -199,7 +199,7 @@ class EmbeddingManagement(BasePage):
                 gr.update(visible=False),
                 gr.update(visible=False),
             ),
-            inputs=None,
+            inputs=[],
             outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
             show_progress="hidden",
         )
@@ -213,7 +213,7 @@ class EmbeddingManagement(BasePage):
             show_progress="hidden",
         ).then(
             self.list_embeddings,
-            inputs=None,
+            inputs=[],
             outputs=[self.emb_list],
         )
         self.btn_close.click(
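On the repeated inputs=None → inputs=[] change: an explicit empty list tells Gradio unambiguously that the callback takes no inputs, which some Gradio versions handle more reliably than None. A self-contained sketch of the pattern (a hypothetical demo app, not from this diff):

import gradio as gr

def list_embeddings():
    # takes no arguments, matching inputs=[]
    return [["openai"], ["azure"]]

with gr.Blocks() as demo:
    emb_list = gr.Dataframe(headers=["name"])
    refresh = gr.Button("Refresh")
    # inputs=[] declares explicitly that list_embeddings takes no inputs
    refresh.click(list_embeddings, inputs=[], outputs=[emb_list])

demo.launch()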
@@ -54,6 +54,7 @@ class BaseFileIndexIndexing(BaseComponent):
     DS = Param(help="The DocStore")
     FSPath = Param(help="The file storage path")
     user_id = Param(help="The user id")
+    private = Param(False, help="Whether this is private index")

     def run(
         self, file_paths: str | Path | list[str | Path], *args, **kwargs
@@ -73,7 +74,9 @@ class BaseFileIndexIndexing(BaseComponent):

     def stream(
         self, file_paths: str | Path | list[str | Path], *args, **kwargs
-    ) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
+    ) -> Generator[
+        Document, None, tuple[list[str | None], list[str | None], list[Document]]
+    ]:
         """Stream the indexing pipeline

         Args:
@@ -87,6 +90,7 @@ class BaseFileIndexIndexing(BaseComponent):
             None if the indexing failed for that file path)
             - the error messages (each error message corresponds to an input file path,
             or None if the indexing was successful for that file path)
+            - the indexed documents in form of list[Documents]
         """
         raise NotImplementedError

@@ -149,3 +153,7 @@ class BaseFileIndexIndexing(BaseComponent):
             msg: the message to log
         """
         print(msg)
+
+    def rebuild_index(self):
+        """Rebuild the index"""
+        raise NotImplementedError
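The widened stream() return type adds a third tuple element carrying the indexed documents, and callers capture it with `yield from`, which passes a generator's return value through (the GraphRAG pipeline below relies on exactly this). A minimal sketch of the Python pattern:

from typing import Generator

def stream() -> Generator[str, None, tuple[list[str], list[str]]]:
    yield "progress..."
    return (["file-1"], [])  # returned, not yielded

def caller() -> Generator[str, None, None]:
    ids, errors = yield from stream()  # captures stream()'s return value
    yield f"done: ids={ids}, errors={errors}"

for msg in caller():
    print(msg)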
libs/ktem/ktem/index/file/graph/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .graph_index import GraphRAGIndex

__all__ = ["GraphRAGIndex"]
libs/ktem/ktem/index/file/graph/graph_index.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from typing import Any

from ktem.index.file import FileIndex

from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import GraphRAGIndexingPipeline, GraphRAGRetrieverPipeline


class GraphRAGIndex(FileIndex):
    def _setup_indexing_cls(self):
        self._indexing_pipeline_cls = GraphRAGIndexingPipeline

    def _setup_retriever_cls(self):
        self._retriever_pipeline_cls = [GraphRAGRetrieverPipeline]

    def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
        """Define the interface of the indexing pipeline"""

        obj = super().get_indexing_pipeline(settings, user_id)
        # disable vectorstore for this kind of Index
        obj.VS = None

        return obj

    def get_retriever_pipelines(
        self, settings: dict, user_id: int, selected: Any = None
    ) -> list["BaseFileIndexRetriever"]:
        _, file_ids, _ = selected
        retrievers = [
            GraphRAGRetrieverPipeline(
                file_ids=file_ids,
                Index=self._resources["Index"],
            )
        ]

        return retrievers
libs/ktem/ktem/index/file/graph/pipelines.py (new file, 359 lines)
@@ -0,0 +1,359 @@
import os
import subprocess
from pathlib import Path
from shutil import rmtree
from typing import Generator
from uuid import uuid4

import pandas as pd
import tiktoken
from ktem.db.models import engine
from sqlalchemy.orm import Session
from theflow.settings import settings

from kotaemon.base import Document, Param, RetrievedDocument

from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline
from .visualize import create_knowledge_graph, visualize_graph

try:
    from graphrag.query.context_builder.entity_extraction import EntityVectorStoreKey
    from graphrag.query.indexer_adapters import (
        read_indexer_entities,
        read_indexer_relationships,
        read_indexer_reports,
        read_indexer_text_units,
    )
    from graphrag.query.input.loaders.dfs import store_entity_semantic_embeddings
    from graphrag.query.llm.oai.embedding import OpenAIEmbedding
    from graphrag.query.llm.oai.typing import OpenaiApiType
    from graphrag.query.structured_search.local_search.mixed_context import (
        LocalSearchMixedContext,
    )
    from graphrag.vector_stores.lancedb import LanceDBVectorStore
except ImportError:
    print(
        (
            "GraphRAG dependencies not installed. "
            "GraphRAG retriever pipeline will not work properly."
        )
    )


filestorage_path = Path(settings.KH_FILESTORAGE_PATH) / "graphrag"
filestorage_path.mkdir(parents=True, exist_ok=True)


def prepare_graph_index_path(graph_id: str):
    root_path = Path(filestorage_path) / graph_id
    input_path = root_path / "input"

    return root_path, input_path


class GraphRAGIndexingPipeline(IndexDocumentPipeline):
    """GraphRAG specific indexing pipeline"""

    def route(self, file_path: Path) -> IndexPipeline:
        """Simply disable the splitter (chunking) for this pipeline"""
        pipeline = super().route(file_path)
        pipeline.splitter = None

        return pipeline

    def store_file_id_with_graph_id(self, file_ids: list[str | None]):
        # create new graph_id and assign them to doc_id in self.Index
        # record in the index
        graph_id = str(uuid4())
        with Session(engine) as session:
            nodes = []
            for file_id in file_ids:
                if not file_id:
                    continue
                nodes.append(
                    self.Index(
                        source_id=file_id,
                        target_id=graph_id,
                        relation_type="graph",
                    )
                )

            session.add_all(nodes)
            session.commit()

        return graph_id

    def write_docs_to_files(self, graph_id: str, docs: list[Document]):
        root_path, input_path = prepare_graph_index_path(graph_id)
        input_path.mkdir(parents=True, exist_ok=True)

        for doc in docs:
            if doc.metadata.get("type", "text") == "text":
                with open(input_path / f"{doc.doc_id}.txt", "w") as f:
                    f.write(doc.text)

        return root_path

    def call_graphrag_index(self, input_path: str):
        # Construct the command
        command = [
            "python",
            "-m",
            "graphrag.index",
            "--root",
            input_path,
            "--reporter",
            "rich",
            "--init",
        ]

        # Run the command
        yield Document(
            channel="debug",
            text="[GraphRAG] Creating index... This can take a long time.",
        )
        result = subprocess.run(command, capture_output=True, text=True)
        print(result.stdout)
        command = command[:-1]

        # Run the command and stream stdout
        with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
            if process.stdout:
                for line in process.stdout:
                    yield Document(channel="debug", text=line)

    def stream(
        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
    ) -> Generator[
        Document, None, tuple[list[str | None], list[str | None], list[Document]]
    ]:
        file_ids, errors, all_docs = yield from super().stream(
            file_paths, reindex=reindex, **kwargs
        )

        # assign graph_id to file_ids
        graph_id = self.store_file_id_with_graph_id(file_ids)
        # call GraphRAG index with docs and graph_id
        graph_index_path = self.write_docs_to_files(graph_id, all_docs)
        yield from self.call_graphrag_index(graph_index_path)

        return file_ids, errors, all_docs


class GraphRAGRetrieverPipeline(BaseFileIndexRetriever):
    """GraphRAG specific retriever pipeline"""

    Index = Param(help="The SQLAlchemy Index table")
    file_ids: list[str] = []

    @classmethod
    def get_user_settings(cls) -> dict:
        return {
            "search_type": {
                "name": "Search type",
                "value": "local",
                "choices": ["local", "global"],
                "component": "dropdown",
                "info": "Whether to use local or global search in the graph.",
            }
        }

    def _build_graph_search(self):
        assert (
            len(self.file_ids) <= 1
        ), "GraphRAG retriever only supports one file_id at a time"

        file_id = self.file_ids[0]
        # retrieve the graph_id from the index
        with Session(engine) as session:
            graph_id = (
                session.query(self.Index.target_id)
                .filter(self.Index.source_id == file_id)
                .filter(self.Index.relation_type == "graph")
                .first()
            )
            graph_id = graph_id[0] if graph_id else None
            assert graph_id, f"GraphRAG index not found for file_id: {file_id}"

        root_path, _ = prepare_graph_index_path(graph_id)
        output_path = root_path / "output"
        child_paths = sorted(
            list(output_path.iterdir()), key=lambda x: x.stem, reverse=True
        )

        # get the latest child path
        assert child_paths, "GraphRAG index output not found"
        latest_child_path = Path(child_paths[0]) / "artifacts"

        INPUT_DIR = latest_child_path
        LANCEDB_URI = str(INPUT_DIR / "lancedb")
        COMMUNITY_REPORT_TABLE = "create_final_community_reports"
        ENTITY_TABLE = "create_final_nodes"
        ENTITY_EMBEDDING_TABLE = "create_final_entities"
        RELATIONSHIP_TABLE = "create_final_relationships"
        TEXT_UNIT_TABLE = "create_final_text_units"
        COMMUNITY_LEVEL = 2

        # read nodes table to get community and degree data
        entity_df = pd.read_parquet(f"{INPUT_DIR}/{ENTITY_TABLE}.parquet")
        entity_embedding_df = pd.read_parquet(
            f"{INPUT_DIR}/{ENTITY_EMBEDDING_TABLE}.parquet"
        )
        entities = read_indexer_entities(
            entity_df, entity_embedding_df, COMMUNITY_LEVEL
        )

        # load description embeddings to an in-memory lancedb vectorstore
        # to connect to a remote db, specify url and port values.
        description_embedding_store = LanceDBVectorStore(
            collection_name="entity_description_embeddings",
        )
        description_embedding_store.connect(db_uri=LANCEDB_URI)
        if Path(LANCEDB_URI).is_dir():
            rmtree(LANCEDB_URI)
        _ = store_entity_semantic_embeddings(
            entities=entities, vectorstore=description_embedding_store
        )
        print(f"Entity count: {len(entity_df)}")

        # Read relationships
        relationship_df = pd.read_parquet(f"{INPUT_DIR}/{RELATIONSHIP_TABLE}.parquet")
        relationships = read_indexer_relationships(relationship_df)

        # Read community reports
        report_df = pd.read_parquet(f"{INPUT_DIR}/{COMMUNITY_REPORT_TABLE}.parquet")
        reports = read_indexer_reports(report_df, entity_df, COMMUNITY_LEVEL)

        # Read text units
        text_unit_df = pd.read_parquet(f"{INPUT_DIR}/{TEXT_UNIT_TABLE}.parquet")
        text_units = read_indexer_text_units(text_unit_df)

        embedding_model = os.getenv("GRAPHRAG_EMBEDDING_MODEL")
        text_embedder = OpenAIEmbedding(
            api_key=os.getenv("OPENAI_API_KEY"),
            api_base=None,
            api_type=OpenaiApiType.OpenAI,
            model=embedding_model,
            deployment_name=embedding_model,
            max_retries=20,
        )
        token_encoder = tiktoken.get_encoding("cl100k_base")

        context_builder = LocalSearchMixedContext(
            community_reports=reports,
            text_units=text_units,
            entities=entities,
            relationships=relationships,
            covariates=None,
            entity_text_embeddings=description_embedding_store,
            embedding_vectorstore_key=EntityVectorStoreKey.ID,
            # if the vectorstore uses entity title as ids,
            # set this to EntityVectorStoreKey.TITLE
            text_embedder=text_embedder,
            token_encoder=token_encoder,
        )
        return context_builder

    def _to_document(self, header: str, context_text: str) -> RetrievedDocument:
        return RetrievedDocument(
            text=context_text,
            metadata={
                "file_name": header,
                "type": "table",
                "llm_trulens_score": 1.0,
            },
            score=1.0,
        )

    def format_context_records(self, context_records) -> list[RetrievedDocument]:
        entities = context_records.get("entities", [])
        relationships = context_records.get("relationships", [])
        reports = context_records.get("reports", [])
        sources = context_records.get("sources", [])

        docs = []

        context: str = ""

        header = "<b>Entities</b>\n"
        context = entities[["entity", "description"]].to_markdown(index=False)
        docs.append(self._to_document(header, context))

        header = "\n<b>Relationships</b>\n"
        context = relationships[["source", "target", "description"]].to_markdown(
            index=False
        )
        docs.append(self._to_document(header, context))

        header = "\n<b>Reports</b>\n"
        context = ""
        for idx, row in reports.iterrows():
            title, content = row["title"], row["content"]
            context += f"\n\n<h5>Report <b>{title}</b></h5>\n"
            context += content
        docs.append(self._to_document(header, context))

        header = "\n<b>Sources</b>\n"
        context = ""
        for idx, row in sources.iterrows():
            title, content = row["id"], row["text"]
            context += f"\n\n<h5>Source <b>#{title}</b></h5>\n"
            context += content
        docs.append(self._to_document(header, context))

        return docs

    def plot_graph(self, context_records):
        relationships = context_records.get("relationships", [])
        G = create_knowledge_graph(relationships)
        plot = visualize_graph(G)
        return plot

    def generate_relevant_scores(self, text, documents: list[RetrievedDocument]):
        return documents

    def run(
        self,
        text: str,
    ) -> list[RetrievedDocument]:
        if not self.file_ids:
            return []
        context_builder = self._build_graph_search()

        local_context_params = {
            "text_unit_prop": 0.5,
            "community_prop": 0.1,
            "conversation_history_max_turns": 5,
            "conversation_history_user_turns_only": True,
            "top_k_mapped_entities": 10,
            "top_k_relationships": 10,
            "include_entity_rank": False,
            "include_relationship_weight": False,
            "include_community_rank": False,
            "return_candidate_context": False,
            "embedding_vectorstore_key": EntityVectorStoreKey.ID,
            # set this to EntityVectorStoreKey.TITLE if
            # the vectorstore uses entity title as ids
            "max_tokens": 12_000,
            # change this based on the token limit you have on your model
            # (if you are using a model with 8k limit, a good setting could be 5000)
        }

        context_text, context_records = context_builder.build_context(
            query=text,
            conversation_history=None,
            **local_context_params,
        )
        documents = self.format_context_records(context_records)
        plot = self.plot_graph(context_records)

        return documents + [
            RetrievedDocument(
                text="",
                metadata={
                    "file_name": "GraphRAG",
                    "type": "plot",
                    "data": plot,
                },
            ),
        ]
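Note how the indexing step above shells out to `graphrag.index` twice: once with --init to scaffold the workspace, then again without it to build the index while streaming stdout line by line. The streaming half of that pattern in isolation (a generic sketch with a placeholder command, not tied to graphrag):

import subprocess

command = ["python", "-c", "print('line 1'); print('line 2')"]  # placeholder command

# Stream each stdout line as it is produced, instead of waiting for the process to exit
with subprocess.Popen(command, stdout=subprocess.PIPE, text=True) as process:
    if process.stdout:
        for line in process.stdout:
            print("streamed:", line.rstrip())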
libs/ktem/ktem/index/file/graph/visualize.py (new file, 102 lines)
@@ -0,0 +1,102 @@
import networkx as nx
import plotly.graph_objects as go
from plotly.io import to_json


def create_knowledge_graph(df):
    """
    create nx Graph from DataFrame relations data
    """
    G = nx.Graph()
    for _, row in df.iterrows():
        source = row["source"]
        target = row["target"]
        attributes = {k: v for k, v in row.items() if k not in ["source", "target"]}
        G.add_edge(source, target, **attributes)

    return G


def visualize_graph(G):
    pos = nx.spring_layout(G, dim=2)

    edge_x = []
    edge_y = []
    edge_texts = nx.get_edge_attributes(G, "description")
    to_display_edge_texts = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.append(x0)
        edge_x.append(x1)
        edge_x.append(None)
        edge_y.append(y0)
        edge_y.append(y1)
        edge_y.append(None)
        to_display_edge_texts.append(edge_texts[edge])

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        text=to_display_edge_texts,
        line=dict(width=0.5, color="#888"),
        hoverinfo="text",
        mode="lines",
    )

    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

    node_adjacencies = []
    node_text = []
    node_size = []
    for node_id, adjacencies in enumerate(G.adjacency()):
        degree = len(adjacencies[1])
        node_adjacencies.append(degree)
        node_text.append(adjacencies[0])
        node_size.append(15 if degree < 5 else (30 if degree < 10 else 60))

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        textfont=dict(
            family="Courier New, monospace",
            size=10,  # Set the font size here
        ),
        textposition="top center",
        mode="markers+text",
        hoverinfo="text",
        text=node_text,
        marker=dict(
            showscale=True,
            # colorscale options
            size=node_size,
            colorscale="YlGnBu",
            reversescale=True,
            color=node_adjacencies,
            colorbar=dict(
                thickness=5,
                xanchor="left",
                titleside="right",
            ),
            line_width=2,
        ),
    )

    fig = go.Figure(
        data=[edge_trace, node_trace],
        layout=go.Layout(
            showlegend=False,
            hovermode="closest",
            margin=dict(b=20, l=5, r=5, t=40),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
        ),
    )
    fig.update_layout(autosize=True)

    return to_json(fig)
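Taken together, the two helpers turn the GraphRAG relationships table into a Plotly figure serialized as JSON. A minimal usage sketch with a toy DataFrame (column names follow the relationships schema used above; assumes the module is importable as shown):

import pandas as pd

from ktem.index.file.graph.visualize import create_knowledge_graph, visualize_graph

df = pd.DataFrame(
    {
        "source": ["A", "B"],
        "target": ["B", "C"],
        "description": ["A relates to B", "B relates to C"],
    }
)
G = create_knowledge_graph(df)
fig_json = visualize_graph(G)  # JSON string, ready for a Plotly front end
print(fig_json[:80])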
@@ -4,8 +4,9 @@ from typing import Any, Optional, Type
 from ktem.components import filestorage_path, get_docstore, get_vectorstore
 from ktem.db.engine import engine
 from ktem.index.base import BaseIndex
-from sqlalchemy import Column, DateTime, Integer, String
+from sqlalchemy import JSON, Column, DateTime, Integer, String, UniqueConstraint
 from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.sql import func
 from theflow.settings import settings as flowsettings
 from theflow.utils.modules import import_dotted_string
@@ -52,27 +53,60 @@ class FileIndex(BaseIndex):
     - File storage path
     """
     Base = declarative_base()
-    Source = type(
-        "Source",
-        (Base,),
-        {
-            "__tablename__": f"index__{self.id}__source",
-            "id": Column(
-                String,
-                primary_key=True,
-                default=lambda: str(uuid.uuid4()),
-                unique=True,
-            ),
-            "name": Column(String, unique=True),
-            "path": Column(String),
-            "size": Column(Integer, default=0),
-            "text_length": Column(Integer, default=0),
-            "date_created": Column(
-                DateTime(timezone=True), server_default=func.now()
-            ),
-            "user": Column(Integer, default=1),
-        },
-    )
+    if self.config.get("private", False):
+        Source = type(
+            "Source",
+            (Base,),
+            {
+                "__tablename__": f"index__{self.id}__source",
+                "__table_args__": (
+                    UniqueConstraint("name", "user", name="_name_user_uc"),
+                ),
+                "id": Column(
+                    String,
+                    primary_key=True,
+                    default=lambda: str(uuid.uuid4()),
+                    unique=True,
+                ),
+                "name": Column(String),
+                "path": Column(String),
+                "size": Column(Integer, default=0),
+                "date_created": Column(
+                    DateTime(timezone=True), server_default=func.now()
+                ),
+                "user": Column(Integer, default=1),
+                "note": Column(
+                    MutableDict.as_mutable(JSON),  # type: ignore
+                    default={},
+                ),
+            },
+        )
+    else:
+        Source = type(
+            "Source",
+            (Base,),
+            {
+                "__tablename__": f"index__{self.id}__source",
+                "id": Column(
+                    String,
+                    primary_key=True,
+                    default=lambda: str(uuid.uuid4()),
+                    unique=True,
+                ),
+                "name": Column(String, unique=True),
+                "path": Column(String),
+                "size": Column(Integer, default=0),
+                "date_created": Column(
+                    DateTime(timezone=True), server_default=func.now()
+                ),
+                "user": Column(Integer, default=1),
+                "note": Column(
+                    MutableDict.as_mutable(JSON),  # type: ignore
+                    default={},
+                ),
+            },
+        )
     Index = type(
         "IndexTable",
         (Base,),
@@ -85,6 +119,7 @@ class FileIndex(BaseIndex):
             "user": Column(Integer, default=1),
         },
     )
+
     self._vs: BaseVectorStore = get_vectorstore(f"index_{self.id}")
     self._docstore: BaseDocumentStore = get_docstore(f"index_{self.id}")
     self._fs_path = filestorage_path / f"index_{self.id}"
@@ -358,8 +393,6 @@ class FileIndex(BaseIndex):
         for key, value in settings.items():
             if key.startswith(prefix):
                 stripped_settings[key[len(prefix) :]] = value
-            else:
-                stripped_settings[key] = value

         obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
         obj.Source = self._resources["Source"]
@@ -368,6 +401,7 @@ class FileIndex(BaseIndex):
         obj.DS = self._docstore
         obj.FSPath = self._fs_path
         obj.user_id = user_id
+        obj.private = self.config.get("private", False)

         return obj

@@ -380,8 +414,6 @@ class FileIndex(BaseIndex):
         for key, value in settings.items():
             if key.startswith(prefix):
                 stripped_settings[key[len(prefix) :]] = value
-            else:
-                stripped_settings[key] = value

         # transform selected id
         selected_ids: Optional[list[str]] = self._selector_ui.get_selected_ids(selected)
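The private-index branch swaps the global unique constraint on name for a composite (name, user) constraint, which is what allows two users to upload files with the same filename. The same idea in a plain declarative model (a sketch, not the dynamic type(...) construction above; the table name is hypothetical):

from sqlalchemy import Column, Integer, String, UniqueConstraint
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Source(Base):
    __tablename__ = "source_example"  # hypothetical table name
    __table_args__ = (UniqueConstraint("name", "user", name="_name_user_uc"),)

    id = Column(Integer, primary_key=True)
    name = Column(String)  # no longer globally unique
    user = Column(Integer, default=1)
    # the same (name, user) pair can only be inserted once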
libs/ktem/ktem/index/file/knet/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .knet_index import KnowledgeNetworkFileIndex

__all__ = ["KnowledgeNetworkFileIndex"]
libs/ktem/ktem/index/file/knet/knet_index.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from typing import Any

from ktem.index.file import FileIndex

from ..base import BaseFileIndexIndexing, BaseFileIndexRetriever
from .pipelines import KnetIndexingPipeline, KnetRetrievalPipeline


class KnowledgeNetworkFileIndex(FileIndex):
    @classmethod
    def get_admin_settings(cls):
        admin_settings = super().get_admin_settings()

        # remove embedding from admin settings
        # as we don't need it
        admin_settings.pop("embedding")
        return admin_settings

    def _setup_indexing_cls(self):
        self._indexing_pipeline_cls = KnetIndexingPipeline

    def _setup_retriever_cls(self):
        self._retriever_pipeline_cls = [KnetRetrievalPipeline]

    def get_indexing_pipeline(self, settings, user_id) -> BaseFileIndexIndexing:
        """Define the interface of the indexing pipeline"""

        obj = super().get_indexing_pipeline(settings, user_id)
        # disable vectorstore for this kind of Index
        # also set the collection_name for API call
        obj.VS = None
        obj.collection_name = f"kh_index_{self.id}"

        return obj

    def get_retriever_pipelines(
        self, settings: dict, user_id: int, selected: Any = None
    ) -> list["BaseFileIndexRetriever"]:
        retrievers = super().get_retriever_pipelines(settings, user_id, selected)

        for obj in retrievers:
            # disable vectorstore for this kind of Index
            # also set the collection_name for API call
            obj.VS = None
            obj.collection_name = f"kh_index_{self.id}"

        return retrievers
libs/ktem/ktem/index/file/knet/pipelines.py (new file, 169 lines)
@@ -0,0 +1,169 @@
import base64
import json
import os
from pathlib import Path
from typing import Optional, Sequence

import requests
import yaml

from kotaemon.base import RetrievedDocument
from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring

from ..pipelines import BaseFileIndexRetriever, IndexDocumentPipeline, IndexPipeline


class KnetIndexingPipeline(IndexDocumentPipeline):
    """Knowledge Network specific indexing pipeline"""

    # collection name for external indexing call
    collection_name: str = "default"

    @classmethod
    def get_user_settings(cls):
        return {
            "reader_mode": {
                "name": "Index parser",
                "value": "knowledge_network",
                "choices": [
                    ("Default (KN)", "knowledge_network"),
                ],
                "component": "dropdown",
            },
        }

    def route(self, file_path: Path) -> IndexPipeline:
        """Simply disable the splitter (chunking) for this pipeline"""
        pipeline = super().route(file_path)
        pipeline.splitter = None
        # assign IndexPipeline collection name to parse to loader
        pipeline.collection_name = self.collection_name

        return pipeline


class KnetRetrievalPipeline(BaseFileIndexRetriever):
    DEFAULT_KNET_ENDPOINT: str = "http://127.0.0.1:8081/retrieve"

    collection_name: str = "default"
    rerankers: Sequence[BaseReranking] = [LLMReranking.withx()]

    def encode_image_base64(self, image_path: str | Path) -> bytes | str:
        """Convert image to base64"""
        img_base64 = "data:image/png;base64,{}"
        with open(image_path, "rb") as image_file:
            return img_base64.format(
                base64.b64encode(image_file.read()).decode("utf-8")
            )

    def run(
        self,
        text: str,
        doc_ids: Optional[list[str]] = None,
        *args,
        **kwargs,
    ) -> list[RetrievedDocument]:
        """Retrieve document excerpts similar to the text

        Args:
            text: the text to retrieve similar documents
            doc_ids: list of document ids to constraint the retrieval
        """
        print("searching in doc_ids", doc_ids)
        if not doc_ids:
            return []

        docs: list[RetrievedDocument] = []
        params = {
            "query": text,
            "collection": self.collection_name,
            "meta_filters": {"doc_name": doc_ids},
        }
        params["meta_filters"] = json.dumps(params["meta_filters"])
        response = requests.get(self.DEFAULT_KNET_ENDPOINT, params=params)
        metadata_translation = {
            "TABLE": "table",
            "FIGURE": "image",
        }

        if response.status_code == 200:
            # Load YAML content from the response content
            chunks = yaml.safe_load(response.content)
            for chunk in chunks:
                metadata = chunk["node"]["metadata"]
                metadata["type"] = metadata_translation.get(
                    metadata.pop("content_type", ""), ""
                )
                metadata["file_name"] = metadata.pop("company_name", "")

                # load image from returned path
                image_path = metadata.get("image_path", "")
                if image_path and os.path.isfile(image_path):
                    base64_im = self.encode_image_base64(image_path)
                    # explicitly set document type
                    metadata["type"] = "image"
                    metadata["image_origin"] = base64_im

                docs.append(
                    RetrievedDocument(text=chunk["node"]["text"], metadata=metadata)
                )
        else:
            raise IOError(f"{response.status_code}: {response.text}")

        for reranker in self.rerankers:
            docs = reranker(documents=docs, query=text)

        return docs

    @classmethod
    def get_user_settings(cls) -> dict:
        from ktem.llms.manager import llms

        try:
            reranking_llm = llms.get_default_name()
            reranking_llm_choices = list(llms.options().keys())
        except Exception:
            reranking_llm = None
            reranking_llm_choices = []

        return {
            "reranking_llm": {
                "name": "LLM for scoring",
                "value": reranking_llm,
                "component": "dropdown",
                "choices": reranking_llm_choices,
                "special_type": "llm",
            },
            "retrieval_mode": {
                "name": "Retrieval mode",
                "value": "hybrid",
                "choices": ["vector", "text", "hybrid"],
                "component": "dropdown",
            },
        }

    @classmethod
    def get_pipeline(cls, user_settings, index_settings, selected):
        """Get retriever objects associated with the index

        Args:
            settings: the settings of the app
            kwargs: other arguments
        """
        from ktem.llms.manager import llms

        retriever = cls(
            rerankers=[LLMTrulensScoring()],
        )

        # hacky way to input doc_ids to retriever.run() call (through theflow)
        kwargs = {".doc_ids": selected}
        retriever.set_run(kwargs, temp=False)

        for reranker in retriever.rerankers:
            if isinstance(reranker, LLMReranking):
                reranker.llm = llms.get(
                    user_settings["reranking_llm"], llms.get_default()
                )

        return retriever
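KnetRetrievalPipeline delegates retrieval to an external HTTP service and parses the YAML body it returns. The request half of run() in isolation (a sketch against the default endpoint above; the service has to be running for this to return anything):

import json

import requests
import yaml

params = {
    "query": "example question",
    "collection": "default",
    # doc_name filter is serialized to JSON, as in run() above
    "meta_filters": json.dumps({"doc_name": ["doc-1"]}),
}
response = requests.get("http://127.0.0.1:8081/retrieve", params=params)
if response.status_code == 200:
    chunks = yaml.safe_load(response.content)
    for chunk in chunks:
        print(chunk["node"]["metadata"], chunk["node"]["text"][:60])
else:
    raise IOError(f"{response.status_code}: {response.text}")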
@@ -2,25 +2,29 @@ from __future__ import annotations

 import logging
 import shutil
+import threading
+import time
 import warnings
 from collections import defaultdict
+from copy import deepcopy
 from functools import lru_cache
 from hashlib import sha256
 from pathlib import Path
-from typing import Generator, Optional
+from typing import Generator, Optional, Sequence

+import tiktoken
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader
-from llama_index.readers.file.base import default_file_metadata_func
+from llama_index.core.readers.file.base import default_file_metadata_func
-from llama_index.vector_stores import (
+from llama_index.core.vector_stores import (
     FilterCondition,
     FilterOperator,
     MetadataFilter,
     MetadataFilters,
 )
-from llama_index.vector_stores.types import VectorStoreQueryMode
+from llama_index.core.vector_stores.types import VectorStoreQueryMode
 from sqlalchemy import delete, select
 from sqlalchemy.orm import Session
 from theflow.settings import settings
@@ -29,8 +33,18 @@ from theflow.utils.modules import import_dotted_string
 from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
 from kotaemon.embeddings import BaseEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
-from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
-from kotaemon.indices.rankings import BaseReranking, LLMReranking
+from kotaemon.indices.ingests.files import (
+    KH_DEFAULT_FILE_EXTRACTORS,
+    adobe_reader,
+    azure_reader,
+    unstructured,
+)
+from kotaemon.indices.rankings import (
+    BaseReranking,
+    CohereReranking,
+    LLMReranking,
+    LLMTrulensScoring,
+)
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever
@@ -60,6 +74,9 @@ def dev_settings():
     return file_extractors, chunk_size, chunk_overlap


+_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode
+
+
 class DocumentRetrievalPipeline(BaseFileIndexRetriever):
     """Retrieve relevant document
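The module-level _default_token_func backs the new token counts shown in the file list (per the "replace text length with token" change in this merge). Counting tokens with it looks like this (a standalone sketch):

import tiktoken

# same encoder the module caches at import time
token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode

text = "Kotaemon indexes documents into chunks."
print(len(token_func(text)))  # number of tokens rather than characters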
@@ -75,10 +92,13 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
     """

     embedding: BaseEmbeddings
-    reranker: BaseReranking = LLMReranking.withx()
+    rerankers: Sequence[BaseReranking] = []
+    # use LLM to create relevant scores for displaying on UI
+    llm_scorer: LLMReranking | None = LLMReranking.withx()
     get_extra_table: bool = False
     mmr: bool = False
     top_k: int = 5
+    retrieval_mode: str = "hybrid"

     @Node.auto(depends_on=["embedding", "VS", "DS"])
     def vector_retrieval(self) -> VectorRetrieval:
@@ -86,6 +106,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             embedding=self.embedding,
             vector_store=self.VS,
             doc_store=self.DS,
+            retrieval_mode=self.retrieval_mode,  # type: ignore
+            rerankers=self.rerankers,
         )

     def run(
@@ -101,27 +123,30 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             text: the text to retrieve similar documents
             doc_ids: list of document ids to constraint the retrieval
         """
+        print("searching in doc_ids", doc_ids)
         if not doc_ids:
             logger.info(f"Skip retrieval because of no selected files: {self}")
             return []

-        retrieval_kwargs = {}
+        retrieval_kwargs: dict = {}
         with Session(engine) as session:
             stmt = select(self.Index).where(
-                self.Index.relation_type == "vector",
+                self.Index.relation_type == "document",
                 self.Index.source_id.in_(doc_ids),
             )
             results = session.execute(stmt)
-            vs_ids = [r[0].target_id for r in results.all()]
+            chunk_ids = [r[0].target_id for r in results.all()]

+        # do first round top_k extension
+        retrieval_kwargs["do_extend"] = True
+        retrieval_kwargs["scope"] = chunk_ids
         retrieval_kwargs["filters"] = MetadataFilters(
             filters=[
                 MetadataFilter(
-                    key="doc_id",
-                    value=vs_id,
-                    operator=FilterOperator.EQ,
+                    key="file_id",
+                    value=doc_ids,
+                    operator=FilterOperator.IN,
                 )
-                for vs_id in vs_ids
             ],
             condition=FilterCondition.OR,
         )
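The filter rewrite above collapses one EQ filter per chunk id into a single list-valued IN filter over file ids, so the vector store receives one small filter instead of potentially thousands. The new shape in isolation (a sketch with placeholder ids, using the same imports as this file):

from llama_index.core.vector_stores import (
    FilterCondition,
    FilterOperator,
    MetadataFilter,
    MetadataFilters,
)

doc_ids = ["file-1", "file-2"]  # placeholder file ids
filters = MetadataFilters(
    filters=[
        MetadataFilter(
            key="file_id",
            value=doc_ids,  # one filter, list-valued
            operator=FilterOperator.IN,
        )
    ],
    condition=FilterCondition.OR,
)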
@@ -132,9 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             retrieval_kwargs["mmr_threshold"] = 0.5

         # rerank
+        s_time = time.time()
+        print(f"retrieval_kwargs: {retrieval_kwargs.keys()}")
         docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
-        if docs and self.get_from_path("reranker"):
-            docs = self.reranker(docs, query=text)
+        print("retrieval step took", time.time() - s_time)

         if not self.get_extra_table:
             return docs
@@ -157,17 +183,30 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             for fn, pls in table_pages.items()
         ]
         if queries:
-            extra_docs = self.vector_retrieval(
-                text="",
-                top_k=50,
-                where=queries[0] if len(queries) == 1 else {"$or": queries},
-            )
-            for doc in extra_docs:
-                if doc.doc_id not in retrieved_id:
-                    docs.append(doc)
+            try:
+                extra_docs = self.vector_retrieval(
+                    text="",
+                    top_k=50,
+                    where=queries[0] if len(queries) == 1 else {"$or": queries},
+                )
+                for doc in extra_docs:
+                    if doc.doc_id not in retrieved_id:
+                        docs.append(doc)
+            except Exception:
+                print("Error retrieving additional tables")

         return docs

+    def generate_relevant_scores(
+        self, query: str, documents: list[RetrievedDocument]
+    ) -> list[RetrievedDocument]:
+        docs = (
+            documents
+            if not self.llm_scorer
+            else self.llm_scorer(documents=documents, query=query)
+        )
+        return docs
+
     @classmethod
     def get_user_settings(cls) -> dict:
         from ktem.llms.manager import llms
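generate_relevant_scores is a thin dispatcher: pass documents through untouched when llm_scorer is disabled, otherwise let the scorer annotate them. The control flow in isolation (a sketch with a stubbed scorer, not the real LLMTrulensScoring):

from typing import Callable, Optional

def generate_relevant_scores(
    docs: list[str],
    query: str,
    llm_scorer: Optional[Callable[..., list[str]]] = None,
) -> list[str]:
    # no scorer configured: return the documents unchanged
    if not llm_scorer:
        return docs
    return llm_scorer(documents=docs, query=query)

stub_scorer = lambda documents, query: [f"{d} (scored for {query!r})" for d in documents]
print(generate_relevant_scores(["chunk a"], "q", stub_scorer))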
@@ -182,43 +221,44 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):

         return {
             "reranking_llm": {
-                "name": "LLM for reranking",
+                "name": "LLM for relevant scoring",
                 "value": reranking_llm,
                 "component": "dropdown",
                 "choices": reranking_llm_choices,
+                "special_type": "llm",
             },
-            "separate_embedding": {
-                "name": "Use separate embedding",
-                "value": False,
-                "choices": [("Yes", True), ("No", False)],
-                "component": "dropdown",
-            },
             "num_retrieval": {
                 "name": "Number of document chunks to retrieve",
-                "value": 3,
+                "value": 10,
                 "component": "number",
             },
             "retrieval_mode": {
                 "name": "Retrieval mode",
-                "value": "vector",
+                "value": "hybrid",
                 "choices": ["vector", "text", "hybrid"],
                 "component": "dropdown",
             },
             "prioritize_table": {
                 "name": "Prioritize table",
-                "value": True,
+                "value": False,
                 "choices": [True, False],
                 "component": "checkbox",
             },
             "mmr": {
                 "name": "Use MMR",
-                "value": True,
+                "value": False,
                 "choices": [True, False],
                 "component": "checkbox",
             },
             "use_reranking": {
                 "name": "Use reranking",
-                "value": False,
+                "value": True,
+                "choices": [True, False],
+                "component": "checkbox",
+            },
+            "use_llm_reranking": {
+                "name": "Use LLM relevant scoring",
+                "value": True,
                 "choices": [True, False],
                 "component": "checkbox",
             },
@@ -232,6 +272,8 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
             settings: the settings of the app
             kwargs: other arguments
         """
+        use_llm_reranking = user_settings.get("use_llm_reranking", False)
+
         retriever = cls(
             get_extra_table=user_settings["prioritize_table"],
             top_k=user_settings["num_retrieval"],
@@ -241,16 +283,26 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
                     "embedding", embedding_models_manager.get_default_name()
                 )
             ],
+            retrieval_mode=user_settings["retrieval_mode"],
+            llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
+            rerankers=[CohereReranking()],
         )
         if not user_settings["use_reranking"]:
-            retriever.reranker = None  # type: ignore
-        else:
-            retriever.reranker.llm = llms.get(
-                user_settings["reranking_llm"], llms.get_default()
-            )
+            retriever.rerankers = []  # type: ignore
+
+        for reranker in retriever.rerankers:
+            if isinstance(reranker, LLMReranking):
+                reranker.llm = llms.get(
+                    user_settings["reranking_llm"], llms.get_default()
+                )
+
+        if retriever.llm_scorer:
+            retriever.llm_scorer.llm = llms.get(
+                user_settings["reranking_llm"], llms.get_default()
+            )

         kwargs = {".doc_ids": selected}
-        retriever.set_run(kwargs, temp=True)
+        retriever.set_run(kwargs, temp=False)
         return retriever
|
||||||
"""Index a single file"""
|
"""Index a single file"""
|
||||||
|
|
||||||
loader: BaseReader
|
loader: BaseReader
|
||||||
splitter: BaseSplitter
|
splitter: BaseSplitter | None
|
||||||
chunk_batch_size: int = 50
|
chunk_batch_size: int = 200
|
||||||
|
|
||||||
Source = Param(help="The SQLAlchemy Source table")
|
Source = Param(help="The SQLAlchemy Source table")
|
||||||
Index = Param(help="The SQLAlchemy Index table")
|
Index = Param(help="The SQLAlchemy Index table")
|
||||||
|
@@ -267,6 +319,9 @@ class IndexPipeline(BaseComponent):
     DS = Param(help="The DocStore")
     FSPath = Param(help="The file storage path")
     user_id = Param(help="The user id")
+    collection_name: str = "default"
+    private: bool = False
+    run_embedding_in_thread: bool = False
     embedding: BaseEmbeddings

     @Node.auto(depends_on=["Source", "Index", "embedding"])
@ -276,31 +331,81 @@ class IndexPipeline(BaseComponent):
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
|
def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
|
||||||
|
s_time = time.time()
|
||||||
|
text_docs = []
|
||||||
|
non_text_docs = []
|
||||||
|
thumbnail_docs = []
|
||||||
|
|
||||||
|
for doc in docs:
|
||||||
|
doc_type = doc.metadata.get("type", "text")
|
||||||
|
if doc_type == "text":
|
||||||
|
text_docs.append(doc)
|
||||||
|
elif doc_type == "thumbnail":
|
||||||
|
thumbnail_docs.append(doc)
|
||||||
|
else:
|
||||||
|
non_text_docs.append(doc)
|
||||||
|
|
||||||
|
print(f"Got {len(thumbnail_docs)} page thumbnails")
|
||||||
|
page_label_to_thumbnail = {
|
||||||
|
doc.metadata["page_label"]: doc.doc_id for doc in thumbnail_docs
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.splitter:
|
||||||
|
all_chunks = self.splitter(text_docs)
|
||||||
|
else:
|
||||||
|
all_chunks = text_docs
|
||||||
|
|
||||||
|
# add the thumbnails doc_id to the chunks
|
||||||
|
for chunk in all_chunks:
|
||||||
|
page_label = chunk.metadata.get("page_label", None)
|
||||||
|
if page_label and page_label in page_label_to_thumbnail:
|
||||||
|
chunk.metadata["thumbnail_doc_id"] = page_label_to_thumbnail[page_label]
|
||||||
|
|
||||||
|
to_index_chunks = all_chunks + non_text_docs + thumbnail_docs
|
||||||
|
|
||||||
|
# add to doc store
|
||||||
chunks = []
|
chunks = []
|
||||||
n_chunks = 0
|
n_chunks = 0
|
||||||
for cidx, chunk in enumerate(self.splitter(docs)):
|
chunk_size = self.chunk_batch_size * 4
|
||||||
chunks.append(chunk)
|
for start_idx in range(0, len(to_index_chunks), chunk_size):
|
||||||
if cidx % self.chunk_batch_size == 0:
|
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
|
||||||
self.handle_chunks(chunks, file_id)
|
self.handle_chunks_docstore(chunks, file_id)
|
||||||
n_chunks += len(chunks)
|
|
||||||
chunks = []
|
|
||||||
yield Document(
|
|
||||||
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
|
|
||||||
)
|
|
||||||
|
|
||||||
if chunks:
|
|
||||||
self.handle_chunks(chunks, file_id)
|
|
||||||
n_chunks += len(chunks)
|
n_chunks += len(chunks)
|
||||||
yield Document(
|
yield Document(
|
||||||
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
|
f" => [{file_name}] Processed {n_chunks} chunks",
|
||||||
|
channel="debug",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def insert_chunks_to_vectorstore():
|
||||||
|
chunks = []
|
||||||
|
n_chunks = 0
|
||||||
|
chunk_size = self.chunk_batch_size
|
||||||
|
for start_idx in range(0, len(to_index_chunks), chunk_size):
|
||||||
|
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
|
||||||
|
self.handle_chunks_vectorstore(chunks, file_id)
|
||||||
|
n_chunks += len(chunks)
|
||||||
|
if self.VS:
|
||||||
|
yield Document(
|
||||||
|
f" => [{file_name}] Created embedding for {n_chunks} chunks",
|
||||||
|
channel="debug",
|
||||||
|
)
|
||||||
|
|
||||||
|
# run vector indexing in thread if specified
|
||||||
|
if self.run_embedding_in_thread:
|
||||||
|
print("Running embedding in thread")
|
||||||
|
threading.Thread(
|
||||||
|
target=lambda: list(insert_chunks_to_vectorstore())
|
||||||
|
).start()
|
||||||
|
else:
|
||||||
|
yield from insert_chunks_to_vectorstore()
|
||||||
|
|
||||||
|
print("indexing step took", time.time() - s_time)
|
||||||
return n_chunks
|
return n_chunks
|
||||||
|
|
||||||
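The reworked `handle_docs` makes two passes over the same chunk list: docstore writes in large batches (`chunk_batch_size * 4`) and embedding/vectorstore writes in smaller batches, the latter optionally on a background thread so quick uploads return early. A rough sketch of that shape, with the store calls stubbed out:

```python
import threading

chunk_batch_size = 200
to_index_chunks = [f"chunk-{i}" for i in range(1000)]

def batched(items, size):
    # consecutive slices of at most `size` items
    for start in range(0, len(items), size):
        yield items[start : start + size]

# pass 1: docstore writes, large batches
for batch in batched(to_index_chunks, chunk_batch_size * 4):
    print(f"docstore: wrote {len(batch)} chunks")

# pass 2: embedding + vectorstore writes, smaller batches
def insert_to_vectorstore():
    for batch in batched(to_index_chunks, chunk_batch_size):
        print(f"vectorstore: embedded {len(batch)} chunks")

run_in_thread = False  # mirrors run_embedding_in_thread
if run_in_thread:
    threading.Thread(target=insert_to_vectorstore).start()
else:
    insert_to_vectorstore()
```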
-    def handle_chunks(self, chunks, file_id):
+    def handle_chunks_docstore(self, chunks, file_id):
         """Run chunks"""
         # run embedding, add to both vector store and doc store
-        self.vector_indexing(chunks)
+        self.vector_indexing.add_to_docstore(chunks)
 
         # record in the index
         with Session(engine) as session:
@@ -313,16 +418,30 @@ class IndexPipeline(BaseComponent):
                         relation_type="document",
                     )
                 )
-                nodes.append(
-                    self.Index(
-                        source_id=file_id,
-                        target_id=chunk.doc_id,
-                        relation_type="vector",
-                    )
-                )
             session.add_all(nodes)
             session.commit()
 
+    def handle_chunks_vectorstore(self, chunks, file_id):
+        """Run chunks"""
+        # run embedding, add to both vector store and doc store
+        self.vector_indexing.add_to_vectorstore(chunks)
+        self.vector_indexing.write_chunk_to_file(chunks)
+
+        if self.VS:
+            # record in the index
+            with Session(engine) as session:
+                nodes = []
+                for chunk in chunks:
+                    nodes.append(
+                        self.Index(
+                            source_id=file_id,
+                            target_id=chunk.doc_id,
+                            relation_type="vector",
+                        )
+                    )
+                session.add_all(nodes)
+                session.commit()
+
     def get_id_if_exists(self, file_path: Path) -> Optional[str]:
         """Check if the file is already indexed
 
@@ -332,8 +451,16 @@ class IndexPipeline(BaseComponent):
         Returns:
             the file id if the file is indexed, otherwise None
         """
+        if self.private:
+            cond: tuple = (
+                self.Source.name == file_path.name,
+                self.Source.user == self.user_id,
+            )
+        else:
+            cond = (self.Source.name == file_path.name,)
+
         with Session(engine) as session:
-            stmt = select(self.Source).where(self.Source.name == file_path.name)
+            stmt = select(self.Source).where(*cond)
            item = session.execute(stmt).first()
             if item:
                 return item[0].id
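The private-index check builds the WHERE clause as a tuple so one `select(...).where(*cond)` covers both cases. The same pattern in a free-standing SQLAlchemy snippet (table and engine here are illustrative):

```python
from sqlalchemy import Column, Integer, String, create_engine, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Source(Base):
    __tablename__ = "source"
    id = Column(Integer, primary_key=True)
    name = Column(String)
    user = Column(String)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

private, user_id, fname = True, "u1", "report.pdf"
if private:
    cond: tuple = (Source.name == fname, Source.user == user_id)
else:
    cond = (Source.name == fname,)

with Session(engine) as session:
    item = session.execute(select(Source).where(*cond)).first()
    print(item)  # None in this empty database
```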
@@ -369,20 +496,36 @@ class IndexPipeline(BaseComponent):
     def finish(self, file_id: str, file_path: Path) -> str:
         """Finish the indexing"""
         with Session(engine) as session:
-            stmt = select(self.Index.target_id).where(self.Index.source_id == file_id)
-            doc_ids = [_[0] for _ in session.execute(stmt)]
-            if doc_ids:
+            stmt = select(self.Source).where(self.Source.id == file_id)
+            result = session.execute(stmt).first()
+            if not result:
+                return file_id
+
+            item = result[0]
+
+            # populate the number of tokens
+            doc_ids_stmt = select(self.Index.target_id).where(
+                self.Index.source_id == file_id,
+                self.Index.relation_type == "document",
+            )
+            doc_ids = [_[0] for _ in session.execute(doc_ids_stmt)]
+            token_func = self.get_token_func()
+            if doc_ids and token_func:
                 docs = self.DS.get(doc_ids)
-                stmt = select(self.Source).where(self.Source.id == file_id)
-                result = session.execute(stmt).first()
-                if result:
-                    item = result[0]
-                    item.text_length = sum([len(doc.text) for doc in docs])
+                item.note["tokens"] = sum([len(token_func(doc.text)) for doc in docs])
+
+            # populate the note
+            item.note["loader"] = self.get_from_path("loader").__class__.__name__
             session.add(item)
             session.commit()
 
         return file_id
 
+    def get_token_func(self):
+        """Get the token function for calculating the number of tokens"""
+        return _default_token_func
+
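`finish` now stores a token count in `Source.note` instead of a raw character length. With `get_token_func` returning a tokenizer-like callable, the bookkeeping is just:

```python
# A whitespace split stands in for _default_token_func here; the real
# function is expected to behave like a tokenizer's encode().
token_func = str.split

docs = ["first chunk of text", "second chunk"]
note = {}
note["tokens"] = sum(len(token_func(text)) for text in docs)
print(note)  # {'tokens': 6}
```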
     def delete_file(self, file_id: str):
         """Delete a file from the db, including its chunks in docstore and vectorstore
 
@@ -398,44 +541,24 @@ class IndexPipeline(BaseComponent):
             for each in index:
                 if each[0].relation_type == "vector":
                     vs_ids.append(each[0].target_id)
-                else:
+                elif each[0].relation_type == "document":
                     ds_ids.append(each[0].target_id)
                 session.delete(each[0])
             session.commit()
-        self.VS.delete(vs_ids)
-        self.DS.delete(ds_ids)
 
-    def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
-        """Index the file and return the file id"""
-        # check for duplication
-        file_path = Path(file_path).resolve()
-        file_id = self.get_id_if_exists(file_path)
-        if file_id is not None:
-            if not reindex:
-                raise ValueError(
-                    f"File {file_path.name} already indexed. Please rerun with "
-                    "reindex=True to force reindexing."
-                )
-            else:
-                # remove the existing records
-                self.delete_file(file_id)
-                file_id = self.store_file(file_path)
-        else:
-            # add record to db
-            file_id = self.store_file(file_path)
+        if vs_ids and self.VS:
+            self.VS.delete(vs_ids)
+        if ds_ids:
+            self.DS.delete(ds_ids)
 
-        # extract the file
-        extra_info = default_file_metadata_func(str(file_path))
-        docs = self.loader.load_data(file_path, extra_info=extra_info)
-        for _ in self.handle_docs(docs, file_id, file_path.name):
-            continue
-        self.finish(file_id, file_path)
-
-        return file_id
+    def run(
+        self, file_path: str | Path, reindex: bool, **kwargs
+    ) -> tuple[str, list[Document]]:
+        raise NotImplementedError
 
     def stream(
         self, file_path: str | Path, reindex: bool, **kwargs
-    ) -> Generator[Document, None, str]:
+    ) -> Generator[Document, None, tuple[str, list[Document]]]:
         # check for duplication
         file_path = Path(file_path).resolve()
         file_id = self.get_id_if_exists(file_path)
@@ -456,6 +579,9 @@ class IndexPipeline(BaseComponent):
 
         # extract the file
         extra_info = default_file_metadata_func(str(file_path))
+        extra_info["file_id"] = file_id
+        extra_info["collection_name"] = self.collection_name
 
         yield Document(f" => Converting {file_path.name} to text", channel="debug")
         docs = self.loader.load_data(file_path, extra_info=extra_info)
         yield Document(f" => Converted {file_path.name} to text", channel="debug")
@@ -464,7 +590,7 @@ class IndexPipeline(BaseComponent):
         self.finish(file_id, file_path)
 
         yield Document(f" => Finished indexing {file_path.name}", channel="debug")
-        return file_id
+        return file_id, docs
 
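`stream` now returns `(file_id, docs)` through the generator's return value, which callers capture with `yield from`. The mechanics in isolation:

```python
from typing import Generator

def stream() -> Generator[str, None, tuple[str, list[str]]]:
    yield "converting..."
    return "file-1", ["doc-a", "doc-b"]  # carried in StopIteration.value

def caller():
    file_id, docs = yield from stream()  # re-yields progress, captures return
    yield f"{file_id}: {len(docs)} docs"

print(list(caller()))  # ['converting...', 'file-1: 2 docs']
```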

 class IndexDocumentPipeline(BaseFileIndexIndexing):
@@ -479,16 +605,54 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
     decide which pipeline should be used.
     """
 
+    reader_mode: str = Param("default", help="The reader mode")
     embedding: BaseEmbeddings
+    run_embedding_in_thread: bool = False
+
+    @Param.auto(depends_on="reader_mode")
+    def readers(self):
+        readers = deepcopy(KH_DEFAULT_FILE_EXTRACTORS)
+        print("reader_mode", self.reader_mode)
+        if self.reader_mode == "adobe":
+            readers[".pdf"] = adobe_reader
+        elif self.reader_mode == "azure-di":
+            readers[".pdf"] = azure_reader
+
+        dev_readers, _, _ = dev_settings()
+        readers.update(dev_readers)
+
+        return readers
+
+    @classmethod
+    def get_user_settings(cls):
+        return {
+            "reader_mode": {
+                "name": "File loader",
+                "value": "default",
+                "choices": [
+                    ("Default (open-source)", "default"),
+                    ("Adobe API (figure+table extraction)", "adobe"),
+                    (
+                        "Azure AI Document Intelligence (figure+table extraction)",
+                        "azure-di",
+                    ),
+                ],
+                "component": "dropdown",
+            },
+        }
+
     @classmethod
     def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
+        use_quick_index_mode = user_settings.get("quick_index_mode", False)
+        print("use_quick_index_mode", use_quick_index_mode)
         obj = cls(
             embedding=embedding_models_manager[
                 index_settings.get(
                     "embedding", embedding_models_manager.get_default_name()
                 )
-            ]
+            ],
+            run_embedding_in_thread=use_quick_index_mode,
+            reader_mode=user_settings.get("reader_mode", "default"),
         )
         return obj
 
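`readers` is derived from `reader_mode`, overriding only the `.pdf` entry when the Adobe or Azure DI loader is selected. The selection logic on its own (reader objects replaced by labels):

```python
from copy import deepcopy

KH_DEFAULT_FILE_EXTRACTORS = {".pdf": "default_pdf_reader", ".txt": "txt_reader"}

def build_readers(reader_mode: str) -> dict:
    readers = deepcopy(KH_DEFAULT_FILE_EXTRACTORS)
    if reader_mode == "adobe":
        readers[".pdf"] = "adobe_reader"
    elif reader_mode == "azure-di":
        readers[".pdf"] = "azure_reader"
    return readers

print(build_readers("azure-di")[".pdf"])  # azure_reader
print(build_readers("default")[".txt"])   # txt_reader
```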
@@ -497,16 +661,17 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
 
         Can subclass this method for a more elaborate pipeline routing strategy.
         """
-        readers, chunk_size, chunk_overlap = dev_settings()
+        _, chunk_size, chunk_overlap = dev_settings()
 
-        ext = file_path.suffix
-        reader = readers.get(ext, KH_DEFAULT_FILE_EXTRACTORS.get(ext, None))
+        ext = file_path.suffix.lower()
+        reader = self.readers.get(ext, unstructured)
         if reader is None:
             raise NotImplementedError(
                 f"No supported pipeline to index {file_path.name}. Please specify "
                 "the suitable pipeline for this file type in the settings."
             )
 
+        print("Using reader", reader)
         pipeline: IndexPipeline = IndexPipeline(
             loader=reader,
             splitter=TokenSplitter(
@@ -515,50 +680,37 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
                 separator="\n\n",
                 backup_separators=["\n", ".", "\u200B"],
             ),
+            run_embedding_in_thread=self.run_embedding_in_thread,
             Source=self.Source,
             Index=self.Index,
             VS=self.VS,
             DS=self.DS,
             FSPath=self.FSPath,
             user_id=self.user_id,
+            private=self.private,
             embedding=self.embedding,
         )
 
         return pipeline
 
     def run(
-        self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
+        self, file_paths: str | Path | list[str | Path], *args, **kwargs
     ) -> tuple[list[str | None], list[str | None]]:
-        """Return a list of indexed file ids, and a list of errors"""
-        if not isinstance(file_paths, list):
-            file_paths = [file_paths]
-
-        file_ids: list[str | None] = []
-        errors: list[str | None] = []
-        for file_path in file_paths:
-            file_path = Path(file_path)
-
-            try:
-                pipeline = self.route(file_path)
-                file_id = pipeline.run(file_path, reindex=reindex, **kwargs)
-                file_ids.append(file_id)
-                errors.append(None)
-            except Exception as e:
-                logger.error(e)
-                file_ids.append(None)
-                errors.append(str(e))
-
-        return file_ids, errors
+        raise NotImplementedError
 
     def stream(
         self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
-    ) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
+    ) -> Generator[
+        Document, None, tuple[list[str | None], list[str | None], list[Document]]
+    ]:
         """Return a list of indexed file ids, and a list of errors"""
         if not isinstance(file_paths, list):
             file_paths = [file_paths]
 
         file_ids: list[str | None] = []
         errors: list[str | None] = []
+        all_docs = []
 
         n_files = len(file_paths)
         for idx, file_path in enumerate(file_paths):
             file_path = Path(file_path)
@@ -569,9 +721,10 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
 
             try:
                 pipeline = self.route(file_path)
-                file_id = yield from pipeline.stream(
+                file_id, docs = yield from pipeline.stream(
                     file_path, reindex=reindex, **kwargs
                 )
+                all_docs.extend(docs)
                 file_ids.append(file_id)
                 errors.append(None)
                 yield Document(
@@ -579,7 +732,7 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
                     channel="index",
                 )
             except Exception as e:
-                logger.error(e)
+                logger.exception(e)
                 file_ids.append(None)
                 errors.append(str(e))
                 yield Document(
@@ -591,4 +744,4 @@ class IndexDocumentPipeline(BaseFileIndexIndexing):
                     channel="index",
                 )
 
-        return file_ids, errors
+        return file_ids, errors, all_docs

@@ -1,5 +1,9 @@
+import html
 import os
+import shutil
 import tempfile
+import zipfile
+from copy import deepcopy
 from pathlib import Path
 from typing import Generator
 
@@ -9,8 +13,12 @@ from gradio.data_classes import FileData
 from gradio.utils import NamedString
 from ktem.app import BasePage
 from ktem.db.engine import engine
+from ktem.utils.render import Render
 from sqlalchemy import select
 from sqlalchemy.orm import Session
+from theflow.settings import settings as flowsettings
+
+DOWNLOAD_MESSAGE = "Press again to download"
 
 
 class File(gr.File):
@@ -143,28 +151,57 @@ class FileIndexPage(BasePage):
             )
 
             gr.Markdown("## File List")
+            self.filter = gr.Textbox(
+                value="",
+                label="Filter by name:",
+                info=(
+                    "(1) Case-insensitive. "
+                    "(2) Search with empty string to show all files."
+                ),
+            )
             self.file_list_state = gr.State(value=None)
             self.file_list = gr.DataFrame(
-                headers=["id", "name", "size", "text_length", "date_created"],
+                headers=[
+                    "id",
+                    "name",
+                    "size",
+                    "tokens",
+                    "loader",
+                    "date_created",
+                ],
+                column_widths=["0%", "50%", "8%", "7%", "15%", "20%"],
                 interactive=False,
+                wrap=False,
+                elem_id="file_list_view",
             )
 
+            with gr.Row():
+                self.deselect_button = gr.Button(
+                    "Close",
+                    visible=False,
+                )
+                self.delete_button = gr.Button(
+                    "Delete",
+                    variant="stop",
+                    visible=False,
+                )
+            with gr.Row():
+                self.is_zipped_state = gr.State(value=False)
+                self.download_all_button = gr.DownloadButton(
+                    "Download all files",
+                    visible=True,
+                )
+                self.download_single_button = gr.DownloadButton(
+                    "Download file",
+                    visible=False,
+                )
+
             with gr.Row() as self.selection_info:
                 self.selected_file_id = gr.State(value=None)
                 with gr.Column(scale=2):
                     self.selected_panel = gr.Markdown(self.selected_panel_false)
 
-                self.deselect_button = gr.Button(
-                    "Deselect",
-                    visible=False,
-                    elem_classes=["right-button"],
-                )
-                self.delete_button = gr.Button(
-                    "Delete",
-                    variant="stop",
-                    visible=False,
-                    elem_classes=["right-button"],
-                )
+            self.chunks = gr.HTML(visible=False)
 
     def on_subscribe_public_events(self):
         """Subscribe to the declared public event of the app"""
@@ -189,12 +226,58 @@ class FileIndexPage(BasePage):
         )
 
     def file_selected(self, file_id):
+        chunks = []
+        if file_id is not None:
+            # get the chunks
+
+            Index = self._index._resources["Index"]
+            with Session(engine) as session:
+                matches = session.execute(
+                    select(Index).where(
+                        Index.source_id == file_id,
+                        Index.relation_type == "document",
+                    )
+                )
+                doc_ids = [doc.target_id for (doc,) in matches]
+                docs = self._index._docstore.get(doc_ids)
+                docs = sorted(
+                    docs, key=lambda x: x.metadata.get("page_label", float("inf"))
+                )
+
+            for idx, doc in enumerate(docs):
+                title = html.escape(
+                    f"{doc.text[:50]}..." if len(doc.text) > 50 else doc.text
+                )
+                doc_type = doc.metadata.get("type", "text")
+                content = ""
+                if doc_type == "text":
+                    content = html.escape(doc.text)
+                elif doc_type == "table":
+                    content = Render.table(doc.text)
+                elif doc_type == "image":
+                    content = Render.image(
+                        url=doc.metadata.get("image_origin", ""), text=doc.text
+                    )
+
+                header_prefix = f"[{idx+1}/{len(docs)}]"
+                if doc.metadata.get("page_label"):
+                    header_prefix += f" [Page {doc.metadata['page_label']}]"
+
+                chunks.append(
+                    Render.collapsible(
+                        header=f"{header_prefix} {title}",
+                        content=content,
+                    )
+                )
         return (
+            gr.update(value="".join(chunks), visible=file_id is not None),
+            gr.update(visible=file_id is not None),
             gr.update(visible=file_id is not None),
             gr.update(visible=file_id is not None),
         )
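`file_selected` renders each chunk as a collapsible HTML block via `Render.collapsible`. A plausible sketch of such a helper using `<details>`/`<summary>` (an assumption about its output shape, not the actual ktem implementation):

```python
import html

def collapsible(header: str, content: str) -> str:
    # <details> gives a native expand/collapse widget in the browser
    return f"<details><summary>{header}</summary>{content}</details>"

doc_text = "A chunk of <b>raw</b> text longer than fifty characters, truncated..."
title = html.escape(doc_text[:50] + "...")
block = collapsible(header=f"[1/3] [Page 2] {title}", content=html.escape(doc_text))
print(block)
```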

     def delete_event(self, file_id):
+        file_name = ""
         with Session(engine) as session:
             source = session.execute(
                 select(self._index._resources["Source"]).where(
@@ -202,6 +285,7 @@ class FileIndexPage(BasePage):
                 )
             ).first()
             if source:
+                file_name = source[0].name
                 session.delete(source[0])
 
             vs_ids, ds_ids = [], []
@@ -213,15 +297,16 @@ class FileIndexPage(BasePage):
             for each in index:
                 if each[0].relation_type == "vector":
                     vs_ids.append(each[0].target_id)
-                else:
+                elif each[0].relation_type == "document":
                     ds_ids.append(each[0].target_id)
                 session.delete(each[0])
             session.commit()
 
-        self._index._vs.delete(vs_ids)
+        if vs_ids:
+            self._index._vs.delete(vs_ids)
         self._index._docstore.delete(ds_ids)
 
-        gr.Info(f"File {file_id} has been deleted")
+        gr.Info(f"File {file_name} has been deleted")
 
         return None, self.selected_panel_false
 
@@ -231,6 +316,57 @@ class FileIndexPage(BasePage):
             gr.update(visible=False),
         )
 
+    def download_single_file(self, is_zipped_state, file_id):
+        with Session(engine) as session:
+            source = session.execute(
+                select(self._index._resources["Source"]).where(
+                    self._index._resources["Source"].id == file_id
+                )
+            ).first()
+            if source:
+                target_file_name = Path(source[0].name)
+                zip_files = []
+                for file_name in os.listdir(flowsettings.KH_CHUNKS_OUTPUT_DIR):
+                    if target_file_name.stem in file_name:
+                        zip_files.append(
+                            os.path.join(flowsettings.KH_CHUNKS_OUTPUT_DIR, file_name)
+                        )
+                for file_name in os.listdir(flowsettings.KH_MARKDOWN_OUTPUT_DIR):
+                    if target_file_name.stem in file_name:
+                        zip_files.append(
+                            os.path.join(flowsettings.KH_MARKDOWN_OUTPUT_DIR, file_name)
+                        )
+                zip_file_path = os.path.join(
+                    flowsettings.KH_ZIP_OUTPUT_DIR, target_file_name.stem
+                )
+                with zipfile.ZipFile(f"{zip_file_path}.zip", "w") as zipMe:
+                    for file in zip_files:
+                        zipMe.write(file, arcname=os.path.basename(file))
+
+        if is_zipped_state:
+            new_button = gr.DownloadButton(label="Download", value=None)
+        else:
+            new_button = gr.DownloadButton(
+                label=DOWNLOAD_MESSAGE, value=f"{zip_file_path}.zip"
+            )
+
+        return not is_zipped_state, new_button
+
+    def download_all_files(self):
+        zip_files = []
+        for file_name in os.listdir(flowsettings.KH_CHUNKS_OUTPUT_DIR):
+            zip_files.append(os.path.join(flowsettings.KH_CHUNKS_OUTPUT_DIR, file_name))
+        for file_name in os.listdir(flowsettings.KH_MARKDOWN_OUTPUT_DIR):
+            zip_files.append(
+                os.path.join(flowsettings.KH_MARKDOWN_OUTPUT_DIR, file_name)
+            )
+        zip_file_path = os.path.join(flowsettings.KH_ZIP_OUTPUT_DIR, "all")
+        with zipfile.ZipFile(f"{zip_file_path}.zip", "w") as zipMe:
+            for file in zip_files:
+                arcname = Path(file)
+                zipMe.write(file, arcname=arcname.name)
+        return gr.DownloadButton(label=DOWNLOAD_MESSAGE, value=f"{zip_file_path}.zip")
 
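Both download handlers collect the chunk and markdown artifacts for a file and zip them with flattened names; passing `arcname` keeps the archive free of absolute paths. Minimal usage of that zipfile pattern:

```python
import os
import tempfile
import zipfile

with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "report_chunks.json")
    with open(src, "w") as f:
        f.write("{}")

    zip_path = os.path.join(tmp, "report.zip")
    with zipfile.ZipFile(zip_path, "w") as zf:
        # arcname drops the directory part so the zip stays flat
        zf.write(src, arcname=os.path.basename(src))

    with zipfile.ZipFile(zip_path) as zf:
        print(zf.namelist())  # ['report_chunks.json']
```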
     def on_register_events(self):
         """Register all events to the app"""
         onDeleted = (
@@ -241,35 +377,61 @@ class FileIndexPage(BasePage):
             )
             .then(
                 fn=lambda: (None, self.selected_panel_false),
-                inputs=None,
+                inputs=[],
                 outputs=[self.selected_file_id, self.selected_panel],
                 show_progress="hidden",
             )
             .then(
                 fn=self.list_file,
-                inputs=[self._app.user_id],
+                inputs=[self._app.user_id, self.filter],
                 outputs=[self.file_list_state, self.file_list],
             )
+            .then(
+                fn=self.file_selected,
+                inputs=[self.selected_file_id],
+                outputs=[
+                    self.chunks,
+                    self.deselect_button,
+                    self.delete_button,
+                    self.download_single_button,
+                ],
+                show_progress="hidden",
+            )
         )
         for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
             onDeleted = onDeleted.then(**event)
 
         self.deselect_button.click(
             fn=lambda: (None, self.selected_panel_false),
-            inputs=None,
+            inputs=[],
             outputs=[self.selected_file_id, self.selected_panel],
             show_progress="hidden",
-        )
-        self.selected_panel.change(
+        ).then(
             fn=self.file_selected,
             inputs=[self.selected_file_id],
             outputs=[
+                self.chunks,
                 self.deselect_button,
                 self.delete_button,
+                self.download_single_button,
             ],
             show_progress="hidden",
         )
 
+        self.download_all_button.click(
+            fn=self.download_all_files,
+            inputs=[],
+            outputs=self.download_all_button,
+            show_progress="hidden",
+        )
+
+        self.download_single_button.click(
+            fn=self.download_single_file,
+            inputs=[self.is_zipped_state, self.selected_file_id],
+            outputs=[self.is_zipped_state, self.download_single_button],
+            show_progress="hidden",
+        )
+
         onUploaded = self.upload_button.click(
             fn=lambda: gr.update(visible=True),
             outputs=[self.upload_progress_panel],
@@ -285,9 +447,63 @@ class FileIndexPage(BasePage):
             concurrency_limit=20,
         )
 
+        try:
+            # quick file upload event registration of first Index only
+            if self._index.id == 1:
+                self.quick_upload_state = gr.State(value=[])
+                print("Setting up quick upload event")
+                quickUploadedEvent = (
+                    self._app.chat_page.quick_file_upload.upload(
+                        fn=lambda: gr.update(
+                            value="Please wait for the indexing process "
+                            "to complete before adding your question."
+                        ),
+                        outputs=self._app.chat_page.quick_file_upload_status,
+                    )
+                    .then(
+                        fn=self.index_fn_with_default_loaders,
+                        inputs=[
+                            self._app.chat_page.quick_file_upload,
+                            gr.State(value=False),
+                            self._app.settings_state,
+                            self._app.user_id,
+                        ],
+                        outputs=self.quick_upload_state,
+                    )
+                    .success(
+                        fn=lambda: [
+                            gr.update(value=None),
+                            gr.update(value="select"),
+                        ],
+                        outputs=[
+                            self._app.chat_page.quick_file_upload,
+                            self._app.chat_page._indices_input[0],
+                        ],
+                    )
+                )
+                for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
+                    quickUploadedEvent = quickUploadedEvent.then(**event)
+
+                quickUploadedEvent.success(
+                    fn=lambda x: x,
+                    inputs=self.quick_upload_state,
+                    outputs=self._app.chat_page._indices_input[1],
+                ).then(
+                    fn=lambda: gr.update(value="Indexing completed."),
+                    outputs=self._app.chat_page.quick_file_upload_status,
+                ).then(
+                    fn=self.list_file,
+                    inputs=[self._app.user_id, self.filter],
+                    outputs=[self.file_list_state, self.file_list],
+                    concurrency_limit=20,
+                )
+
+        except Exception as e:
+            print(e)
 
         uploadedEvent = onUploaded.then(
             fn=self.list_file,
-            inputs=[self._app.user_id],
+            inputs=[self._app.user_id, self.filter],
             outputs=[self.file_list_state, self.file_list],
             concurrency_limit=20,
         )
@@ -309,16 +525,64 @@ class FileIndexPage(BasePage):
             inputs=[self.file_list],
             outputs=[self.selected_file_id, self.selected_panel],
             show_progress="hidden",
+        ).then(
+            fn=self.file_selected,
+            inputs=[self.selected_file_id],
+            outputs=[
+                self.chunks,
+                self.deselect_button,
+                self.delete_button,
+                self.download_single_button,
+            ],
+            show_progress="hidden",
+        )
+
+        self.filter.submit(
+            fn=self.list_file,
+            inputs=[self._app.user_id, self.filter],
+            outputs=[self.file_list_state, self.file_list],
+            show_progress="hidden",
         )
 
     def _on_app_created(self):
         """Called when the app is created"""
         self._app.app.load(
             self.list_file,
-            inputs=[self._app.user_id],
+            inputs=[self._app.user_id, self.filter],
             outputs=[self.file_list_state, self.file_list],
         )
 
+    def _may_extract_zip(self, files, zip_dir: str):
+        """Handle zip files"""
+        zip_files = [file for file in files if file.endswith(".zip")]
+        remaining_files = [file for file in files if not file.endswith("zip")]
+
+        # Clean-up <zip_dir> before unzip to remove old files
+        shutil.rmtree(zip_dir, ignore_errors=True)
+
+        for zip_file in zip_files:
+            # Prepare new zip output dir, separated for each files
+            basename = os.path.splitext(os.path.basename(zip_file))[0]
+            zip_out_dir = os.path.join(zip_dir, basename)
+            os.makedirs(zip_out_dir, exist_ok=True)
+            with zipfile.ZipFile(zip_file, "r") as zip_ref:
+                zip_ref.extractall(zip_out_dir)
+
+        n_zip_file = 0
+        for root, dirs, files in os.walk(zip_dir):
+            for file in files:
+                ext = os.path.splitext(file)[1]
+
+                # only allow supported file-types ( not zip )
+                if ext not in [".zip"] and ext in self._supported_file_types:
+                    remaining_files += [os.path.join(root, file)]
+                    n_zip_file += 1
+
+        if n_zip_file > 0:
+            print(f"Update zip files: {n_zip_file}")
+
+        return remaining_files
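`_may_extract_zip` unpacks uploaded archives into a scratch directory, then walks it and keeps only supported extensions. A condensed, self-contained version of that walk:

```python
import os
import tempfile
import zipfile

supported = {".pdf", ".txt", ".md"}

with tempfile.TemporaryDirectory() as zip_dir:
    # build a tiny archive to extract
    archive = os.path.join(zip_dir, "upload.zip")
    with zipfile.ZipFile(archive, "w") as zf:
        zf.writestr("notes.txt", "hello")
        zf.writestr("skip.bin", "ignored")

    out_dir = os.path.join(zip_dir, "out")
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(out_dir)

    kept = []
    for root, _dirs, files in os.walk(out_dir):
        for file in files:
            if os.path.splitext(file)[1] in supported:
                kept.append(os.path.join(root, file))
    print(kept)  # only notes.txt survives the filter
```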

     def index_fn(
         self, files, reindex: bool, settings, user_id
     ) -> Generator[tuple[str, str], None, None]:
@@ -335,6 +599,8 @@ class FileIndexPage(BasePage):
             yield "", ""
             return
 
+        files = self._may_extract_zip(files, flowsettings.KH_ZIP_INPUT_DIR)
+
         errors = self.validate(files)
         if errors:
             gr.Warning(", ".join(errors))
@@ -366,19 +632,61 @@ class FileIndexPage(BasePage):
                 debugs.append(response.text)
                 yield "\n".join(outputs), "\n".join(debugs)
         except StopIteration as e:
-            result, errors = e.value
+            results, index_errors, docs = e.value
         except Exception as e:
             debugs.append(f"Error: {e}")
             yield "\n".join(outputs), "\n".join(debugs)
             return
 
-        n_successes = len([_ for _ in result if _])
+        n_successes = len([_ for _ in results if _])
         if n_successes:
             gr.Info(f"Successfully index {n_successes} files")
         n_errors = len([_ for _ in errors if _])
         if n_errors:
             gr.Warning(f"Have errors for {n_errors} files")
 
+        return results
+
+    def index_fn_with_default_loaders(
+        self, files, reindex: bool, settings, user_id
+    ) -> list["str"]:
+        """Function for quick upload with default loaders
+
+        Args:
+            files: the list of files to be uploaded
+            reindex: whether to reindex the files
+            selected_files: the list of files already selected
+            settings: the settings of the app
+        """
+        print("Overriding with default loaders")
+        exist_ids = []
+        to_process_files = []
+        for str_file_path in files:
+            file_path = Path(str(str_file_path))
+            exist_id = (
+                self._index.get_indexing_pipeline(settings, user_id)
+                .route(file_path)
+                .get_id_if_exists(file_path)
+            )
+            if exist_id:
+                exist_ids.append(exist_id)
+            else:
+                to_process_files.append(str_file_path)
+
+        returned_ids = []
+        settings = deepcopy(settings)
+        settings[f"index.options.{self._index.id}.reader_mode"] = "default"
+        settings[f"index.options.{self._index.id}.quick_index_mode"] = True
+        if to_process_files:
+            _iter = self.index_fn(to_process_files, reindex, settings, user_id)
+            try:
+                while next(_iter):
+                    pass
+            except StopIteration as e:
+                returned_ids = e.value
+
+        return exist_ids + returned_ids
+
     def index_files_from_dir(
         self, folder_path, reindex, settings, user_id
     ) -> Generator[tuple[str, str], None, None]:
@@ -452,7 +760,19 @@ class FileIndexPage(BasePage):
 
         yield from self.index_fn(files, reindex, settings, user_id)
 
-    def list_file(self, user_id):
+    def format_size_human_readable(self, num: float | str, suffix="B"):
+        try:
+            num = float(num)
+        except ValueError:
+            return num
+
+        for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
+            if abs(num) < 1024.0:
+                return f"{num:3.0f}{unit}{suffix}"
+            num /= 1024.0
+        return f"{num:.0f}Yi{suffix}"
 
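`format_size_human_readable` is the classic repeated-division formatter; the same helper also renders the token count by passing an empty suffix. Expected behaviour, assuming the implementation above:

```python
def format_size_human_readable(num, suffix="B"):
    try:
        num = float(num)
    except ValueError:
        return num
    for unit in ("", "K", "M", "G", "T", "P", "E", "Z"):
        if abs(num) < 1024.0:
            return f"{num:3.0f}{unit}{suffix}"
        num /= 1024.0
    return f"{num:.0f}Yi{suffix}"

print(format_size_human_readable(532))        # '532B'
print(format_size_human_readable(2_400_000))  # '  2MB'
print(format_size_human_readable(1234, ""))   # '  1K'
print(format_size_human_readable("-"))        # '-' (non-numeric passes through)
```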
+    def list_file(self, user_id, name_pattern=""):
         if user_id is None:
             # not signed in
             return [], pd.DataFrame.from_records(
@@ -461,7 +781,8 @@ class FileIndexPage(BasePage):
                         "id": "-",
                         "name": "-",
                         "size": "-",
-                        "text_length": "-",
+                        "tokens": "-",
+                        "loader": "-",
                         "date_created": "-",
                     }
                 ]
@@ -472,12 +793,17 @@ class FileIndexPage(BasePage):
             statement = select(Source)
             if self._index.config.get("private", False):
                 statement = statement.where(Source.user == user_id)
+            if name_pattern:
+                statement = statement.where(Source.name.ilike(f"%{name_pattern}%"))
             results = [
                 {
                     "id": each[0].id,
                     "name": each[0].name,
-                    "size": each[0].size,
-                    "text_length": each[0].text_length,
+                    "size": self.format_size_human_readable(each[0].size),
+                    "tokens": self.format_size_human_readable(
+                        each[0].note.get("tokens", "-"), suffix=""
+                    ),
+                    "loader": each[0].note.get("loader", "-"),
                     "date_created": each[0].date_created.strftime("%Y-%m-%d %H:%M:%S"),
                 }
                 for each in session.execute(statement).all()
@@ -492,12 +818,14 @@ class FileIndexPage(BasePage):
                         "id": "-",
                         "name": "-",
                         "size": "-",
-                        "text_length": "-",
+                        "tokens": "-",
+                        "loader": "-",
                         "date_created": "-",
                     }
                 ]
             )
 
+        print(f"{len(results)=}, {len(file_list)=}")
         return results, file_list
 
     def interact_file_list(self, list_files, ev: gr.SelectData):
@@ -561,9 +889,8 @@ class FileSelector(BasePage):
         self.mode = gr.Radio(
             value=default_mode,
             choices=[
-                ("Disabled", "disabled"),
                 ("Search All", "all"),
-                ("Select", "select"),
+                ("Search In File(s)", "select"),
             ],
             container=False,
         )

@@ -123,8 +123,11 @@ class IndexManager:
         )
 
         try:
-            # clean up
-            index.on_delete()
+            try:
+                # clean up
+                index.on_delete()
+            except Exception as e:
+                print(f"Error while deleting index {index.name}: {e}")
 
             # remove from database
             with Session(engine) as sess:

@@ -7,6 +7,21 @@ from ktem.utils.file import YAMLNoDateSafeLoader
 from .manager import IndexManager
 
 
+# UGLY way to restart gradio server by updating atime
+def update_current_module_atime():
+    import os
+    import time
+
+    # Define the file path
+    file_path = __file__
+    print("Updating atime for", file_path)
+
+    # Get the current time
+    current_time = time.time()
+    # Set the modified time (and access time) to the current time
+    os.utime(file_path, (current_time, current_time))
+
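`update_current_module_atime` leans on gradio's dev-reload file watcher: bumping the module's timestamps triggers a server reload, which is how dynamically created index pages get re-registered. The `os.utime` call itself is plain standard library:

```python
import os
import tempfile
import time

with tempfile.NamedTemporaryFile(delete=False) as f:
    path = f.name

now = time.time()
os.utime(path, (now, now))  # set (atime, mtime) to "now"
print(abs(os.path.getmtime(path) - now) < 1)  # True
os.remove(path)
```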
 def format_description(cls):
     user_settings = cls.get_admin_settings()
     params_lines = ["| Name | Default | Description |", "| --- | --- | --- |"]
@@ -29,7 +44,7 @@ class IndexManagement(BasePage):
     def on_building_ui(self):
         with gr.Tab(label="View"):
             self.index_list = gr.DataFrame(
-                headers=["ID", "Name", "Index Type"],
+                headers=["id", "name", "index type"],
                 interactive=False,
             )
 
@@ -95,7 +110,7 @@ class IndexManagement(BasePage):
         """Called when the app is created"""
         self._app.app.load(
             self.list_indices,
-            inputs=None,
+            inputs=[],
             outputs=[self.index_list],
         )
         self._app.app.load(
@@ -117,7 +132,7 @@ class IndexManagement(BasePage):
             self.create_index,
             inputs=[self.name, self.index_type, self.spec],
             outputs=None,
-        ).success(self.list_indices, inputs=None, outputs=[self.index_list]).success(
+        ).success(self.list_indices, inputs=[], outputs=[self.index_list]).success(
             lambda: ("", None, "", self.spec_desc_default),
             outputs=[
                 self.name,
@@ -125,6 +140,8 @@ class IndexManagement(BasePage):
                 self.spec,
                 self.spec_desc,
             ],
+        ).success(
+            update_current_module_atime
         )
         self.index_list.select(
             self.select_index,
@@ -152,7 +169,7 @@ class IndexManagement(BasePage):
                 gr.update(visible=False),
                 gr.update(visible=True),
             ),
-            inputs=None,
+            inputs=[],
             outputs=[
                 self.btn_edit_save,
                 self.btn_delete,
@@ -166,10 +183,8 @@ class IndexManagement(BasePage):
             inputs=[self.selected_index_id],
             outputs=[self.selected_index_id],
             show_progress="hidden",
-        ).then(
-            self.list_indices,
-            inputs=None,
-            outputs=[self.index_list],
+        ).then(self.list_indices, inputs=[], outputs=[self.index_list],).success(
+            update_current_module_atime
         )
         self.btn_delete_no.click(
             lambda: (
@@ -178,7 +193,7 @@ class IndexManagement(BasePage):
                 gr.update(visible=True),
                 gr.update(visible=False),
             ),
-            inputs=None,
+            inputs=[],
             outputs=[
                 self.btn_edit_save,
                 self.btn_delete,
@@ -197,7 +212,7 @@ class IndexManagement(BasePage):
             show_progress="hidden",
         ).then(
             self.list_indices,
-            inputs=None,
+            inputs=[],
             outputs=[self.index_list],
         )
         self.btn_close.click(
@@ -245,16 +260,16 @@ class IndexManagement(BasePage):
         items = []
         for item in self.manager.indices:
             record = {}
-            record["ID"] = item.id
-            record["Name"] = item.name
-            record["Index Type"] = item.__class__.__name__
+            record["id"] = item.id
+            record["name"] = item.name
+            record["index type"] = item.__class__.__name__
             items.append(record)
 
         if items:
             indices_list = pd.DataFrame.from_records(items)
         else:
             indices_list = pd.DataFrame.from_records(
-                [{"ID": "-", "Name": "-", "Index Type": "-"}]
+                [{"id": "-", "name": "-", "index type": "-"}]
             )
 
         return indices_list
@@ -268,7 +283,7 @@ class IndexManagement(BasePage):
         if not ev.selected:
             return -1
 
-        return int(index_list["ID"][ev.index[0]])
+        return int(index_list["id"][ev.index[0]])
 
     def on_selected_index_change(self, selected_index_id: int):
         """Show the relevant index as user selects it on the UI

@@ -3,7 +3,7 @@ from typing import Optional, Type, overload
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from theflow.settings import settings as flowsettings
-from theflow.utils.modules import deserialize
+from theflow.utils.modules import deserialize, import_dotted_string
 
 from kotaemon.llms import ChatLLM
 
@@ -38,7 +38,7 @@ class LLMManager:
 
     def load(self):
         """Load the model pool from database"""
-        self._models, self._info, self._defaut = {}, {}, ""
+        self._models, self._info, self._default = {}, {}, ""
         with Session(engine) as session:
             stmt = select(LLMTable)
             items = session.execute(stmt)
@@ -54,14 +54,12 @@ class LLMManager:
                 self._default = item.name
 
     def load_vendors(self):
-        from kotaemon.llms import (
-            AzureChatOpenAI,
-            ChatOpenAI,
-            EndpointChatLLM,
-            LlamaCppChat,
-        )
+        from kotaemon.llms import AzureChatOpenAI, ChatOpenAI, LlamaCppChat
 
-        self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
+        self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat]
+
+        for extra_vendor in getattr(flowsettings, "KH_LLM_EXTRA_VENDORS", []):
+            self._vendors.append(import_dotted_string(extra_vendor, safe=False))
 
     def __getitem__(self, key: str) -> ChatLLM:
         """Get model by name"""
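`KH_LLM_EXTRA_VENDORS` lets deployments register extra vendor classes by dotted path; `import_dotted_string` comes from theflow. A minimal equivalent with importlib shows the idea (the stdlib target below is a stand-in for a real vendor class):

```python
import importlib

def import_dotted_string(dotted: str):
    # "package.module.ClassName" -> the ClassName object
    module_path, _, attr = dotted.rpartition(".")
    return getattr(importlib.import_module(module_path), attr)

# e.g. flowsettings could declare KH_LLM_EXTRA_VENDORS = ["kotaemon.llms.ChatOpenAI"]
vendor_cls = import_dotted_string("collections.OrderedDict")
print(vendor_cls)  # <class 'collections.OrderedDict'>
```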
@ -112,7 +112,7 @@ class LLMManagement(BasePage):
|
||||||
"""Called when the app is created"""
|
"""Called when the app is created"""
|
||||||
self._app.app.load(
|
self._app.app.load(
|
||||||
self.list_llms,
|
self.list_llms,
|
||||||
inputs=None,
|
inputs=[],
|
||||||
outputs=[self.llm_list],
|
outputs=[self.llm_list],
|
||||||
)
|
)
|
||||||
self._app.app.load(
|
self._app.app.load(
|
||||||
|
@ -140,8 +140,8 @@ class LLMManagement(BasePage):
|
||||||
self.btn_new.click(
|
self.btn_new.click(
|
||||||
self.create_llm,
|
self.create_llm,
|
||||||
inputs=[self.name, self.llm_choices, self.spec, self.default],
|
inputs=[self.name, self.llm_choices, self.spec, self.default],
|
||||||
outputs=None,
|
outputs=[],
|
||||||
).success(self.list_llms, inputs=None, outputs=[self.llm_list]).success(
|
).success(self.list_llms, inputs=[], outputs=[self.llm_list]).success(
|
||||||
lambda: ("", None, "", False, self.spec_desc_default),
|
lambda: ("", None, "", False, self.spec_desc_default),
|
||||||
outputs=[
|
outputs=[
|
||||||
self.name,
|
self.name,
|
||||||
|
@ -176,7 +176,7 @@ class LLMManagement(BasePage):
|
||||||
)
|
)
|
||||||
self.btn_delete.click(
|
self.btn_delete.click(
|
||||||
self.on_btn_delete_click,
|
self.on_btn_delete_click,
|
||||||
inputs=None,
|
inputs=[],
|
||||||
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
|
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
)
|
)
|
||||||
|
@ -187,7 +187,7 @@ class LLMManagement(BasePage):
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
).then(
|
).then(
|
||||||
self.list_llms,
|
self.list_llms,
|
||||||
inputs=None,
|
inputs=[],
|
||||||
outputs=[self.llm_list],
|
outputs=[self.llm_list],
|
||||||
)
|
)
|
||||||
self.btn_delete_no.click(
|
self.btn_delete_no.click(
|
||||||
|
@ -196,7 +196,7 @@ class LLMManagement(BasePage):
|
||||||
gr.update(visible=False),
|
gr.update(visible=False),
|
||||||
gr.update(visible=False),
|
gr.update(visible=False),
|
||||||
),
|
),
|
||||||
inputs=None,
|
inputs=[],
|
||||||
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
|
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
)
|
)
|
||||||
|
@ -210,7 +210,7 @@ class LLMManagement(BasePage):
|
||||||
show_progress="hidden",
|
show_progress="hidden",
|
||||||
).then(
|
).then(
|
||||||
self.list_llms,
|
self.list_llms,
|
||||||
inputs=None,
|
inputs=[],
|
||||||
outputs=[self.llm_list],
|
outputs=[self.llm_list],
|
||||||
)
|
)
|
||||||
self.btn_close.click(
|
self.btn_close.click(
|
||||||
|
|
|
@@ -44,7 +44,7 @@ class App(BaseApp):
         if len(self.index_manager.indices) == 1:
             for index in self.index_manager.indices:
                 with gr.Tab(
-                    f"{index.name} Index",
+                    f"{index.name}",
                     elem_id="indices-tab",
                     elem_classes=[
                         "fill-main-area-height",
@@ -58,7 +58,7 @@ class App(BaseApp):
                 setattr(self, f"_index_{index.id}", page)
         elif len(self.index_manager.indices) > 1:
             with gr.Tab(
-                "Indices",
+                "Files",
                 elem_id="indices-tab",
                 elem_classes=["fill-main-area-height", "scrollable", "indices-tab"],
                 id="indices-tab",
@@ -66,7 +66,7 @@ class App(BaseApp):
             ) as self._tabs["indices-tab"]:
                 for index in self.index_manager.indices:
                     with gr.Tab(
-                        f"{index.name}",
+                        f"{index.name} Collection",
                         elem_id=f"{index.id}-tab",
                     ) as self._tabs[f"{index.id}-tab"]:
                         page = index.get_index_page_ui()
@@ -1,15 +1,25 @@
 import asyncio
+import csv
 from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
 from typing import Optional

 import gradio as gr
+from filelock import FileLock
 from ktem.app import BasePage
 from ktem.components import reasonings
 from ktem.db.models import Conversation, engine
+from ktem.index.file.ui import File
+from ktem.reasoning.prompt_optimization.suggest_conversation_name import (
+    SuggestConvNamePipeline,
+)
+from plotly.io import from_json
 from sqlmodel import Session, select
 from theflow.settings import settings as flowsettings

 from kotaemon.base import Document
+from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS

 from .chat_panel import ChatPanel
 from .chat_suggestion import ChatSuggestion
@@ -17,23 +27,49 @@ from .common import STATE
 from .control import ConversationControl
 from .report import ReportIssue

+DEFAULT_SETTING = "(default)"
+INFO_PANEL_SCALES = {True: 8, False: 4}
+
+
+pdfview_js = """
+function() {
+    // Get all links and attach click event
+    var links = document.getElementsByClassName("pdf-link");
+    for (var i = 0; i < links.length; i++) {
+        links[i].onclick = openModal;
+    }
+}
+"""
+
+
 class ChatPage(BasePage):
     def __init__(self, app):
         self._app = app
         self._indices_input = []

         self.on_building_ui()
+        self._reasoning_type = gr.State(value=None)
+        self._llm_type = gr.State(value=None)
+        self._conversation_renamed = gr.State(value=False)
+        self.info_panel_expanded = gr.State(value=True)

     def on_building_ui(self):
         with gr.Row():
-            self.chat_state = gr.State(STATE)
-            with gr.Column(scale=1, elem_id="conv-settings-panel"):
+            self.state_chat = gr.State(STATE)
+            self.state_retrieval_history = gr.State([])
+            self.state_chat_history = gr.State([])
+            self.state_plot_history = gr.State([])
+            self.state_settings = gr.State({})
+            self.state_info_panel = gr.State("")
+            self.state_plot_panel = gr.State(None)
+
+            with gr.Column(scale=1, elem_id="conv-settings-panel") as self.conv_column:
                 self.chat_control = ConversationControl(self._app)

                 if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
                     self.chat_suggestion = ChatSuggestion(self._app)

-                for index in self._app.index_manager.indices:
+                for index_id, index in enumerate(self._app.index_manager.indices):
                     index.selector = None
                     index_ui = index.get_selector_component_ui()
                     if not index_ui:
@@ -41,7 +77,9 @@ class ChatPage(BasePage):
                         continue

                     index_ui.unrender()  # need to rerender later within Accordion
-                    with gr.Accordion(label=f"{index.name} Index", open=True):
+                    with gr.Accordion(
+                        label=f"{index.name} Collection", open=index_id < 1
+                    ):
                         index_ui.render()
                         gr_index = index_ui.as_gradio_component()
                         if gr_index:
@@ -60,14 +98,66 @@ class ChatPage(BasePage):
                             self._indices_input.append(gr_index)
                         setattr(self, f"_index_{index.id}", index_ui)

+                if len(self._app.index_manager.indices) > 0:
+                    with gr.Accordion(label="Quick Upload") as _:
+                        self.quick_file_upload = File(
+                            file_types=list(KH_DEFAULT_FILE_EXTRACTORS.keys()),
+                            file_count="multiple",
+                            container=True,
+                            show_label=False,
+                        )
+                        self.quick_file_upload_status = gr.Markdown()
+
                 self.report_issue = ReportIssue(self._app)

             with gr.Column(scale=6, elem_id="chat-area"):
                 self.chat_panel = ChatPanel(self._app)

-            with gr.Column(scale=3, elem_id="chat-info-panel"):
+                with gr.Row():
+                    with gr.Accordion(label="Chat settings", open=False):
+                        # a quick switch for reasoning type option
+                        with gr.Row():
+                            gr.HTML("Reasoning method")
+                            gr.HTML("Model")
+
+                        with gr.Row():
+                            reasoning_type_values = [
+                                (DEFAULT_SETTING, DEFAULT_SETTING)
+                            ] + self._app.default_settings.reasoning.settings[
+                                "use"
+                            ].choices
+                            self.reasoning_types = gr.Dropdown(
+                                choices=reasoning_type_values,
+                                value=DEFAULT_SETTING,
+                                container=False,
+                                show_label=False,
+                            )
+                            self.model_types = gr.Dropdown(
+                                choices=self._app.default_settings.reasoning.options[
+                                    "simple"
+                                ]
+                                .settings["llm"]
+                                .choices,
+                                value="",
+                                container=False,
+                                show_label=False,
+                            )
+
+            with gr.Column(
+                scale=INFO_PANEL_SCALES[False], elem_id="chat-info-panel"
+            ) as self.info_column:
                 with gr.Accordion(label="Information panel", open=True):
-                    self.info_panel = gr.HTML()
+                    self.modal = gr.HTML("<div id='pdf-modal'></div>")
+                    self.plot_panel = gr.Plot(visible=False)
+                    self.info_panel = gr.HTML(elem_id="html-info-panel")
+
+    def _json_to_plot(self, json_dict: dict | None):
+        if json_dict:
+            plot = from_json(json_dict)
+            plot = gr.update(visible=True, value=plot)
+        else:
+            plot = gr.update(visible=False)
+        return plot

     def on_register_events(self):
         gr.on(
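
The new `_json_to_plot` helper at the end of the hunk above re-materializes a Plotly figure from its JSON form, so only a JSON string ever has to live in session state or the database. A minimal sketch of the same round trip, assuming any Plotly figure:

```python
import plotly.express as px
from plotly.io import from_json, to_json

fig = px.bar(x=["a", "b"], y=[1, 3])
payload = to_json(fig)         # JSON string, safe to stash in gr.State or a DB row
restored = from_json(payload)  # rebuilds a plotly Figure from that string
assert restored.data[0].type == "bar"
```
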
@@ -98,27 +188,75 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self._app.settings_state,
-                self.chat_state,
+                self._reasoning_type,
+                self._llm_type,
+                self.state_chat,
                 self._app.user_id,
             ]
             + self._indices_input,
             outputs=[
                 self.chat_panel.chatbot,
                 self.info_panel,
-                self.chat_state,
+                self.plot_panel,
+                self.state_plot_panel,
+                self.state_chat,
             ],
             concurrency_limit=20,
             show_progress="minimal",
+        ).success(
+            fn=self.backup_original_info,
+            inputs=[
+                self.chat_panel.chatbot,
+                self._app.settings_state,
+                self.info_panel,
+                self.state_chat_history,
+            ],
+            outputs=[
+                self.state_chat_history,
+                self.state_settings,
+                self.state_info_panel,
+            ],
         ).then(
-            fn=self.update_data_source,
+            fn=self.persist_data_source,
             inputs=[
                 self.chat_control.conversation_id,
+                self._app.user_id,
+                self.info_panel,
+                self.state_plot_panel,
+                self.state_retrieval_history,
+                self.state_plot_history,
                 self.chat_panel.chatbot,
-                self.chat_state,
+                self.state_chat,
             ]
             + self._indices_input,
-            outputs=None,
+            outputs=[
+                self.state_retrieval_history,
+                self.state_plot_history,
+            ],
             concurrency_limit=20,
+        ).success(
+            fn=self.check_and_suggest_name_conv,
+            inputs=self.chat_panel.chatbot,
+            outputs=[
+                self.chat_control.conversation_rn,
+                self._conversation_renamed,
+            ],
+        ).success(
+            self.chat_control.rename_conv,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_control.conversation_rn,
+                self._conversation_renamed,
+                self._app.user_id,
+            ],
+            outputs=[
+                self.chat_control.conversation,
+                self.chat_control.conversation,
+                self.chat_control.conversation_rn,
+            ],
+            show_progress="hidden",
+        ).then(
+            fn=None, inputs=None, outputs=None, js=pdfview_js
         )

         self.chat_panel.regen_btn.click(
@@ -127,33 +265,90 @@ class ChatPage(BasePage):
                 self.chat_control.conversation_id,
                 self.chat_panel.chatbot,
                 self._app.settings_state,
-                self.chat_state,
+                self._reasoning_type,
+                self._llm_type,
+                self.state_chat,
                 self._app.user_id,
             ]
             + self._indices_input,
             outputs=[
                 self.chat_panel.chatbot,
                 self.info_panel,
-                self.chat_state,
+                self.plot_panel,
+                self.state_plot_panel,
+                self.state_chat,
             ],
             concurrency_limit=20,
             show_progress="minimal",
         ).then(
-            fn=self.update_data_source,
+            fn=self.persist_data_source,
             inputs=[
                 self.chat_control.conversation_id,
+                self._app.user_id,
+                self.info_panel,
+                self.state_plot_panel,
+                self.state_retrieval_history,
+                self.state_plot_history,
                 self.chat_panel.chatbot,
-                self.chat_state,
+                self.state_chat,
             ]
             + self._indices_input,
-            outputs=None,
+            outputs=[
+                self.state_retrieval_history,
+                self.state_plot_history,
+            ],
             concurrency_limit=20,
+        ).success(
+            fn=self.check_and_suggest_name_conv,
+            inputs=self.chat_panel.chatbot,
+            outputs=[
+                self.chat_control.conversation_rn,
+                self._conversation_renamed,
+            ],
+        ).success(
+            self.chat_control.rename_conv,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_control.conversation_rn,
+                self._conversation_renamed,
+                self._app.user_id,
+            ],
+            outputs=[
+                self.chat_control.conversation,
+                self.chat_control.conversation,
+                self.chat_control.conversation_rn,
+            ],
+            show_progress="hidden",
+        ).then(
+            fn=None, inputs=None, outputs=None, js=pdfview_js
+        )
+
+        self.chat_control.btn_info_expand.click(
+            fn=lambda is_expanded: (
+                gr.update(scale=INFO_PANEL_SCALES[is_expanded]),
+                not is_expanded,
+            ),
+            inputs=self.info_panel_expanded,
+            outputs=[self.info_column, self.info_panel_expanded],
         )

         self.chat_panel.chatbot.like(
             fn=self.is_liked,
             inputs=[self.chat_control.conversation_id],
             outputs=None,
+        ).success(
+            self.save_log,
+            inputs=[
+                self.chat_control.conversation_id,
+                self.chat_panel.chatbot,
+                self._app.settings_state,
+                self.info_panel,
+                self.state_chat_history,
+                self.state_settings,
+                self.state_info_panel,
+                gr.State(getattr(flowsettings, "KH_APP_DATA_DIR", "logs")),
+            ],
+            outputs=None,
         )

         self.chat_control.btn_new.click(
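
The long chains above rely on the difference between Gradio's `.then()` and `.success()`: `.then()` always schedules the next step, while `.success()` runs it only if the previous step finished without an error, so persistence and auto-rename are skipped when `chat_fn` fails. A reduced sketch with placeholder handlers:

```python
import gradio as gr

def generate(msg):
    return f"echo: {msg}"

def persist(answer):
    print("saved:", answer)

with gr.Blocks() as demo:
    box, out = gr.Textbox(), gr.Textbox()
    box.submit(generate, inputs=box, outputs=out).success(
        persist, inputs=out  # runs only if generate() raised no error
    ).then(
        lambda: print("chain done")  # runs regardless of the previous step
    )
```
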
@@ -163,17 +358,25 @@ class ChatPage(BasePage):
             show_progress="hidden",
         ).then(
             self.chat_control.select_conv,
-            inputs=[self.chat_control.conversation],
+            inputs=[self.chat_control.conversation, self._app.user_id],
             outputs=[
                 self.chat_control.conversation_id,
                 self.chat_control.conversation,
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
                 self.info_panel,
-                self.chat_state,
+                self.state_plot_panel,
+                self.state_retrieval_history,
+                self.state_plot_history,
+                self.chat_control.cb_is_public,
+                self.state_chat,
             ]
             + self._indices_input,
             show_progress="hidden",
+        ).then(
+            fn=self._json_to_plot,
+            inputs=self.state_plot_panel,
+            outputs=self.plot_panel,
         )

         self.chat_control.btn_del.click(
@@ -188,17 +391,25 @@ class ChatPage(BasePage):
             show_progress="hidden",
         ).then(
             self.chat_control.select_conv,
-            inputs=[self.chat_control.conversation],
+            inputs=[self.chat_control.conversation, self._app.user_id],
             outputs=[
                 self.chat_control.conversation_id,
                 self.chat_control.conversation,
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
                 self.info_panel,
-                self.chat_state,
+                self.state_plot_panel,
+                self.state_retrieval_history,
+                self.state_plot_history,
+                self.chat_control.cb_is_public,
+                self.state_chat,
             ]
             + self._indices_input,
             show_progress="hidden",
+        ).then(
+            fn=self._json_to_plot,
+            inputs=self.state_plot_panel,
+            outputs=self.plot_panel,
         ).then(
             lambda: self.toggle_delete(""),
             outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
@@ -207,33 +418,80 @@ class ChatPage(BasePage):
             lambda: self.toggle_delete(""),
             outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
         )
-        self.chat_control.conversation_rn_btn.click(
+        self.chat_control.btn_conversation_rn.click(
+            lambda: gr.update(visible=True),
+            outputs=[
+                self.chat_control.conversation_rn,
+            ],
+        )
+        self.chat_control.conversation_rn.submit(
             self.chat_control.rename_conv,
             inputs=[
                 self.chat_control.conversation_id,
                 self.chat_control.conversation_rn,
+                gr.State(value=True),
                 self._app.user_id,
             ],
-            outputs=[self.chat_control.conversation, self.chat_control.conversation],
+            outputs=[
+                self.chat_control.conversation,
+                self.chat_control.conversation,
+                self.chat_control.conversation_rn,
+            ],
             show_progress="hidden",
         )

         self.chat_control.conversation.select(
             self.chat_control.select_conv,
-            inputs=[self.chat_control.conversation],
+            inputs=[self.chat_control.conversation, self._app.user_id],
             outputs=[
                 self.chat_control.conversation_id,
                 self.chat_control.conversation,
                 self.chat_control.conversation_rn,
                 self.chat_panel.chatbot,
                 self.info_panel,
-                self.chat_state,
+                self.state_plot_panel,
+                self.state_retrieval_history,
+                self.state_plot_history,
+                self.chat_control.cb_is_public,
+                self.state_chat,
             ]
             + self._indices_input,
             show_progress="hidden",
+        ).then(
+            fn=self._json_to_plot,
+            inputs=self.state_plot_panel,
+            outputs=self.plot_panel,
         ).then(
             lambda: self.toggle_delete(""),
             outputs=[self.chat_control._new_delete, self.chat_control._delete_confirm],
+        ).then(
+            fn=None, inputs=None, outputs=None, js=pdfview_js
+        )
+
+        # evidence display on message selection
+        self.chat_panel.chatbot.select(
+            self.message_selected,
+            inputs=[
+                self.state_retrieval_history,
+                self.state_plot_history,
+            ],
+            outputs=[
+                self.info_panel,
+                self.state_plot_panel,
+            ],
+        ).then(
+            fn=self._json_to_plot,
+            inputs=self.state_plot_panel,
+            outputs=self.plot_panel,
+        ).then(
+            fn=None, inputs=None, outputs=None, js=pdfview_js
+        )
+
+        self.chat_control.cb_is_public.change(
+            self.on_set_public_conversation,
+            inputs=[self.chat_control.cb_is_public, self.chat_control.conversation],
+            outputs=None,
+            show_progress="hidden",
         )

         self.report_issue.report_btn.click(
@@ -247,11 +505,26 @@ class ChatPage(BasePage):
                 self._app.settings_state,
                 self._app.user_id,
                 self.info_panel,
-                self.chat_state,
+                self.state_chat,
             ]
             + self._indices_input,
             outputs=None,
         )
+        self.reasoning_types.change(
+            self.reasoning_changed,
+            inputs=[self.reasoning_types],
+            outputs=[self._reasoning_type],
+        )
+        self.model_types.change(
+            lambda x: x,
+            inputs=[self.model_types],
+            outputs=[self._llm_type],
+        )
+        self.chat_control.conversation_id.change(
+            lambda: gr.update(visible=False),
+            outputs=self.plot_panel,
+        )

         if getattr(flowsettings, "KH_FEATURE_CHAT_SUGGESTION", False):
             self.chat_suggestion.example.select(
                 self.chat_suggestion.select_example,
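
The two `.change()` handlers just added implement the per-session override switch: the dropdown value is mirrored into a `gr.State` (through `reasoning_changed`, or a bare `lambda x: x` passthrough for the model), and the persisted application settings are never touched. A condensed sketch of the mirroring pattern, with illustrative names:

```python
import gradio as gr

with gr.Blocks() as demo:
    # session-scoped override; None (or a "(default)" sentinel) means
    # "fall back to the persisted settings"
    llm_override = gr.State(value=None)
    picker = gr.Dropdown(choices=["model-a", "model-b"])  # illustrative choices
    picker.change(lambda x: x, inputs=[picker], outputs=[llm_override])
```
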
@@ -291,6 +564,28 @@ class ChatPage(BasePage):
         else:
             return gr.update(visible=True), gr.update(visible=False)

+    def on_set_public_conversation(self, is_public, convo_id):
+        if not convo_id:
+            gr.Warning("No conversation selected")
+            return
+
+        with Session(engine) as session:
+            statement = select(Conversation).where(Conversation.id == convo_id)
+
+            result = session.exec(statement).one()
+            name = result.name
+
+            if result.is_public != is_public:
+                # Only trigger updating when user
+                # select different value from the current
+                result.is_public = is_public
+                session.add(result)
+                session.commit()
+
+                gr.Info(
+                    f"Conversation: {name} is {'public' if is_public else 'private'}."
+                )
+
     def on_subscribe_public_events(self):
         if self._app.f_user_management:
             self._app.subscribe_event(
@@ -306,25 +601,53 @@ class ChatPage(BasePage):
             self._app.subscribe_event(
                 name="onSignOut",
                 definition={
-                    "fn": lambda: self.chat_control.select_conv(""),
+                    "fn": lambda: self.chat_control.select_conv("", None),
                     "outputs": [
                         self.chat_control.conversation_id,
                         self.chat_control.conversation,
                         self.chat_control.conversation_rn,
                         self.chat_panel.chatbot,
                         self.info_panel,
+                        self.state_plot_panel,
+                        self.state_retrieval_history,
+                        self.state_plot_history,
+                        self.chat_control.cb_is_public,
                     ]
                     + self._indices_input,
                     "show_progress": "hidden",
                 },
             )

-    def update_data_source(self, convo_id, messages, state, *selecteds):
+    def persist_data_source(
+        self,
+        convo_id,
+        user_id,
+        retrieval_msg,
+        plot_data,
+        retrival_history,
+        plot_history,
+        messages,
+        state,
+        *selecteds,
+    ):
         """Update the data source"""
         if not convo_id:
             gr.Warning("No conversation selected")
             return

+        # if not regen, then append the new message
+        if not state["app"].get("regen", False):
+            retrival_history = retrival_history + [retrieval_msg]
+            plot_history = plot_history + [plot_data]
+        else:
+            if retrival_history:
+                print("Updating retrieval history (regen=True)")
+                retrival_history[-1] = retrieval_msg
+                plot_history[-1] = plot_data
+
+        # reset regen state
+        state["app"]["regen"] = False
+
         selecteds_ = {}
         for index in self._app.index_manager.indices:
             if index.selector is None:
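
The regen branch in `persist_data_source` above is what keeps the parallel histories (chat, retrieval, plot) the same length: a normal turn appends one entry, while a regenerated turn overwrites the last entry instead. The same invariant as a stand-alone sketch:

```python
def update_history(history: list, new_item, regen: bool) -> list:
    """Append on a normal turn; overwrite the last entry on regeneration."""
    if not regen:
        return history + [new_item]
    if history:
        history[-1] = new_item
    return history

assert update_history(["e1"], "e2", regen=False) == ["e1", "e2"]
assert update_history(["e1", "e2"], "e2-redo", regen=True) == ["e1", "e2-redo"]
```
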
@@ -339,15 +662,29 @@ class ChatPage(BasePage):
             result = session.exec(statement).one()

             data_source = result.data_source
+            old_selecteds = data_source.get("selected", {})
+            is_owner = result.user == user_id
+
+            # Write down to db
             result.data_source = {
-                "selected": selecteds_,
+                "selected": selecteds_ if is_owner else old_selecteds,
                 "messages": messages,
+                "retrieval_messages": retrival_history,
+                "plot_history": plot_history,
                 "state": state,
                 "likes": deepcopy(data_source.get("likes", [])),
             }
             session.add(result)
             session.commit()

+        return retrival_history, plot_history
+
+    def reasoning_changed(self, reasoning_type):
+        if reasoning_type != DEFAULT_SETTING:
+            # override app settings state (temporary)
+            gr.Info("Reasoning type changed to `{}`".format(reasoning_type))
+        return reasoning_type
+
     def is_liked(self, convo_id, liked: gr.LikeData):
         with Session(engine) as session:
             statement = select(Conversation).where(Conversation.id == convo_id)
@@ -362,7 +699,19 @@ class ChatPage(BasePage):
             session.add(result)
             session.commit()

-    def create_pipeline(self, settings: dict, state: dict, user_id: int, *selecteds):
+    def message_selected(self, retrieval_history, plot_history, msg: gr.SelectData):
+        index = msg.index[0]
+        return retrieval_history[index], plot_history[index]
+
+    def create_pipeline(
+        self,
+        settings: dict,
+        session_reasoning_type: str,
+        session_llm: str,
+        state: dict,
+        user_id: int,
+        *selecteds,
+    ):
         """Create the pipeline from settings

         Args:
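
`message_selected` above leans on Gradio's `gr.SelectData`: for a Chatbot, `evt.index` is a `[row, column]` pair, so `msg.index[0]` is the turn number, which doubles as the index into the parallel retrieval and plot history lists. A sketch of the callback shape (the wiring is in the event-registration hunk earlier):

```python
import gradio as gr

def on_message_selected(retrieval_history, evt: gr.SelectData):
    # evt.index for a Chatbot selection is [message_row, bubble_column]
    turn = evt.index[0]
    return retrieval_history[turn]
```
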
@@ -374,10 +723,23 @@ class ChatPage(BasePage):
         Returns:
             - the pipeline objects
         """
-        reasoning_mode = settings["reasoning.use"]
+        # override reasoning_mode by temporary chat page state
+        print("Session reasoning type", session_reasoning_type)
+        print("Session LLM", session_llm)
+        reasoning_mode = (
+            settings["reasoning.use"]
+            if session_reasoning_type in (DEFAULT_SETTING, None)
+            else session_reasoning_type
+        )
         reasoning_cls = reasonings[reasoning_mode]
+        print("Reasoning class", reasoning_cls)
         reasoning_id = reasoning_cls.get_info()["id"]

+        settings = deepcopy(settings)
+        llm_setting_key = f"reasoning.options.{reasoning_id}.llm"
+        if llm_setting_key in settings and session_llm not in (DEFAULT_SETTING, None):
+            settings[llm_setting_key] = session_llm
+
         # get retrievers
         retrievers = []
         for index in self._app.index_manager.indices:
|
||||||
return pipeline, reasoning_state
|
return pipeline, reasoning_state
|
||||||
|
|
||||||
def chat_fn(
|
def chat_fn(
|
||||||
self, conversation_id, chat_history, settings, state, user_id, *selecteds
|
self,
|
||||||
|
conversation_id,
|
||||||
|
chat_history,
|
||||||
|
settings,
|
||||||
|
reasoning_type,
|
||||||
|
llm_type,
|
||||||
|
state,
|
||||||
|
user_id,
|
||||||
|
*selecteds,
|
||||||
):
|
):
|
||||||
"""Chat function"""
|
"""Chat function"""
|
||||||
chat_input = chat_history[-1][0]
|
chat_input = chat_history[-1][0]
|
||||||
|
@ -413,18 +783,23 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
# construct the pipeline
|
# construct the pipeline
|
||||||
pipeline, reasoning_state = self.create_pipeline(
|
pipeline, reasoning_state = self.create_pipeline(
|
||||||
settings, state, user_id, *selecteds
|
settings, reasoning_type, llm_type, state, user_id, *selecteds
|
||||||
)
|
)
|
||||||
|
print("Reasoning state", reasoning_state)
|
||||||
pipeline.set_output_queue(queue)
|
pipeline.set_output_queue(queue)
|
||||||
|
|
||||||
text, refs = "", ""
|
text, refs, plot, plot_gr = "", "", None, gr.update(visible=False)
|
||||||
msg_placeholder = getattr(
|
msg_placeholder = getattr(
|
||||||
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
|
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
|
||||||
)
|
)
|
||||||
print(msg_placeholder)
|
print(msg_placeholder)
|
||||||
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
|
yield (
|
||||||
|
chat_history + [(chat_input, text or msg_placeholder)],
|
||||||
len_ref = -1 # for logging purpose
|
refs,
|
||||||
|
plot_gr,
|
||||||
|
plot,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
for response in pipeline.stream(chat_input, conversation_id, chat_history):
|
for response in pipeline.stream(chat_input, conversation_id, chat_history):
|
||||||
|
|
||||||
|
@ -446,22 +821,42 @@ class ChatPage(BasePage):
|
||||||
else:
|
else:
|
||||||
refs += response.content
|
refs += response.content
|
||||||
|
|
||||||
if len(refs) > len_ref:
|
if response.channel == "plot":
|
||||||
print(f"Len refs: {len(refs)}")
|
plot = response.content
|
||||||
len_ref = len(refs)
|
plot_gr = self._json_to_plot(plot)
|
||||||
|
|
||||||
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
|
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
|
||||||
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
|
yield (
|
||||||
|
chat_history + [(chat_input, text or msg_placeholder)],
|
||||||
|
refs,
|
||||||
|
plot_gr,
|
||||||
|
plot,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
if not text:
|
if not text:
|
||||||
empty_msg = getattr(
|
empty_msg = getattr(
|
||||||
flowsettings, "KH_CHAT_EMPTY_MSG_PLACEHOLDER", "(Sorry, I don't know)"
|
flowsettings, "KH_CHAT_EMPTY_MSG_PLACEHOLDER", "(Sorry, I don't know)"
|
||||||
)
|
)
|
||||||
print(f"Generate nothing: {empty_msg}")
|
print(f"Generate nothing: {empty_msg}")
|
||||||
yield chat_history + [(chat_input, text or empty_msg)], refs, state
|
yield (
|
||||||
|
chat_history + [(chat_input, text or empty_msg)],
|
||||||
|
refs,
|
||||||
|
plot_gr,
|
||||||
|
plot,
|
||||||
|
state,
|
||||||
|
)
|
||||||
|
|
||||||
def regen_fn(
|
def regen_fn(
|
||||||
self, conversation_id, chat_history, settings, state, user_id, *selecteds
|
self,
|
||||||
|
conversation_id,
|
||||||
|
chat_history,
|
||||||
|
settings,
|
||||||
|
reasoning_type,
|
||||||
|
llm_type,
|
||||||
|
state,
|
||||||
|
user_id,
|
||||||
|
*selecteds,
|
||||||
):
|
):
|
||||||
"""Regen function"""
|
"""Regen function"""
|
||||||
if not chat_history:
|
if not chat_history:
|
||||||
|
@ -470,11 +865,119 @@ class ChatPage(BasePage):
|
||||||
return
|
return
|
||||||
|
|
||||||
state["app"]["regen"] = True
|
state["app"]["regen"] = True
|
||||||
for chat, refs, state in self.chat_fn(
|
yield from self.chat_fn(
|
||||||
conversation_id, chat_history, settings, state, user_id, *selecteds
|
conversation_id,
|
||||||
):
|
chat_history,
|
||||||
new_state = deepcopy(state)
|
settings,
|
||||||
new_state["app"]["regen"] = False
|
reasoning_type,
|
||||||
yield chat, refs, new_state
|
llm_type,
|
||||||
|
state,
|
||||||
|
user_id,
|
||||||
|
*selecteds,
|
||||||
|
)
|
||||||
|
|
||||||
state["app"]["regen"] = False
|
def check_and_suggest_name_conv(self, chat_history):
|
||||||
|
suggest_pipeline = SuggestConvNamePipeline()
|
||||||
|
new_name = gr.update()
|
||||||
|
renamed = False
|
||||||
|
|
||||||
|
# check if this is a newly created conversation
|
||||||
|
if len(chat_history) == 1:
|
||||||
|
suggested_name = suggest_pipeline(chat_history).text[:40]
|
||||||
|
new_name = gr.update(value=suggested_name)
|
||||||
|
renamed = True
|
||||||
|
|
||||||
|
return new_name, renamed
|
||||||
|
|
||||||
|
def backup_original_info(
|
||||||
|
self, chat_history, settings, info_pannel, original_chat_history
|
||||||
|
):
|
||||||
|
original_chat_history.append(chat_history[-1])
|
||||||
|
return original_chat_history, settings, info_pannel
|
||||||
|
|
||||||
|
def save_log(
|
||||||
|
self,
|
||||||
|
conversation_id,
|
||||||
|
chat_history,
|
||||||
|
settings,
|
||||||
|
info_panel,
|
||||||
|
original_chat_history,
|
||||||
|
original_settings,
|
||||||
|
original_info_panel,
|
||||||
|
log_dir,
|
||||||
|
):
|
||||||
|
if not Path(log_dir).exists():
|
||||||
|
Path(log_dir).mkdir(parents=True)
|
||||||
|
|
||||||
|
lock = FileLock(Path(log_dir) / ".lock")
|
||||||
|
# get current date
|
||||||
|
today = datetime.now()
|
||||||
|
formatted_date = today.strftime("%d%m%Y_%H")
|
||||||
|
|
||||||
|
with Session(engine) as session:
|
||||||
|
statement = select(Conversation).where(Conversation.id == conversation_id)
|
||||||
|
result = session.exec(statement).one()
|
||||||
|
|
||||||
|
data_source = deepcopy(result.data_source)
|
||||||
|
likes = data_source.get("likes", [])
|
||||||
|
if not likes:
|
||||||
|
return
|
||||||
|
|
||||||
|
feedback = likes[-1][-1]
|
||||||
|
message_index = likes[-1][0]
|
||||||
|
|
||||||
|
current_message = chat_history[message_index[0]]
|
||||||
|
original_message = original_chat_history[message_index[0]]
|
||||||
|
is_original = all(
|
||||||
|
[
|
||||||
|
current_item == original_item
|
||||||
|
for current_item, original_item in zip(
|
||||||
|
current_message, original_message
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
dataframe = [
|
||||||
|
[
|
||||||
|
conversation_id,
|
||||||
|
message_index,
|
||||||
|
current_message[0],
|
||||||
|
current_message[1],
|
||||||
|
chat_history,
|
||||||
|
settings,
|
||||||
|
info_panel,
|
||||||
|
feedback,
|
||||||
|
is_original,
|
||||||
|
original_message[1],
|
||||||
|
original_chat_history,
|
||||||
|
original_settings,
|
||||||
|
original_info_panel,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
with lock:
|
||||||
|
log_file = Path(log_dir) / f"{formatted_date}_log.csv"
|
||||||
|
is_log_file_exist = log_file.is_file()
|
||||||
|
with open(log_file, "a") as f:
|
||||||
|
writer = csv.writer(f)
|
||||||
|
# write headers
|
||||||
|
if not is_log_file_exist:
|
||||||
|
writer.writerow(
|
||||||
|
[
|
||||||
|
"Conversation ID",
|
||||||
|
"Message ID",
|
||||||
|
"Question",
|
||||||
|
"Answer",
|
||||||
|
"Chat History",
|
||||||
|
"Settings",
|
||||||
|
"Evidences",
|
||||||
|
"Feedback",
|
||||||
|
"Original/ Rewritten",
|
||||||
|
"Original Answer",
|
||||||
|
"Original Chat History",
|
||||||
|
"Original Settings",
|
||||||
|
"Original Evidences",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
writer.writerows(dataframe)
|
||||||
|
|
|
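
`save_log` above appends one row per feedback event to an hourly CSV (the `%d%m%Y_%H` stamp) and guards the file with a `FileLock`, since several Gradio workers may handle likes concurrently. The core pattern reduced to a sketch (header names are illustrative):

```python
import csv
from datetime import datetime
from pathlib import Path

from filelock import FileLock

def append_log_row(log_dir: str, row: list):
    out_dir = Path(log_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    log_file = out_dir / f"{datetime.now().strftime('%d%m%Y_%H')}_log.csv"
    with FileLock(str(out_dir / ".lock")):  # serialize appends across workers
        write_header = not log_file.is_file()
        with open(log_file, "a", newline="") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(["Conversation ID", "Question", "Answer"])
            writer.writerow(row)
```
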
@@ -1,13 +1,20 @@
 import logging
+import os

 import gradio as gr
 from ktem.app import BasePage
-from ktem.db.models import Conversation, engine
+from ktem.db.models import Conversation, User, engine
-from sqlmodel import Session, select
+from sqlmodel import Session, or_, select

+import flowsettings
+
+from ...utils.conversation import sync_retrieval_n_message
 from .common import STATE

 logger = logging.getLogger(__name__)

+ASSETS_DIR = "assets/icons"
+if not os.path.isdir(ASSETS_DIR):
+    ASSETS_DIR = "libs/ktem/ktem/assets/icons"
+
+
 def is_conv_name_valid(name):
@@ -35,14 +42,47 @@ class ConversationControl(BasePage):
             label="Chat sessions",
             choices=[],
             container=False,
-            filterable=False,
+            filterable=True,
             interactive=True,
             elem_classes=["unset-overflow"],
         )

         with gr.Row() as self._new_delete:
-            self.btn_new = gr.Button(value="New", min_width=10, variant="primary")
-            self.btn_del = gr.Button(value="Delete", min_width=10, variant="stop")
+            self.btn_new = gr.Button(
+                value="",
+                icon=f"{ASSETS_DIR}/new.svg",
+                min_width=2,
+                scale=1,
+                size="sm",
+                elem_classes=["no-background", "body-text-color"],
+            )
+            self.btn_del = gr.Button(
+                value="",
+                icon=f"{ASSETS_DIR}/delete.svg",
+                min_width=2,
+                scale=1,
+                size="sm",
+                elem_classes=["no-background", "body-text-color"],
+            )
+            self.btn_conversation_rn = gr.Button(
+                value="",
+                icon=f"{ASSETS_DIR}/rename.svg",
+                min_width=2,
+                scale=1,
+                size="sm",
+                elem_classes=["no-background", "body-text-color"],
+            )
+            self.btn_info_expand = gr.Button(
+                value="",
+                icon=f"{ASSETS_DIR}/sidebar.svg",
+                min_width=2,
+                scale=1,
+                size="sm",
+                elem_classes=["no-background", "body-text-color"],
+            )
+            self.cb_is_public = gr.Checkbox(
+                value=False, label="Shared", min_width=10, scale=4
+            )

         with gr.Row(visible=False) as self._delete_confirm:
             self.btn_del_conf = gr.Button(
@@ -54,28 +94,60 @@ class ConversationControl(BasePage):

         with gr.Row():
             self.conversation_rn = gr.Text(
+                label="(Enter) to save",
                 placeholder="Conversation name",
-                container=False,
+                container=True,
                 scale=5,
                 min_width=10,
                 interactive=True,
-            )
-            self.conversation_rn_btn = gr.Button(
-                value="Rename",
-                scale=1,
-                min_width=10,
-                elem_classes=["no-background", "body-text-color", "bold-text"],
+                visible=False,
             )

     def load_chat_history(self, user_id):
         """Reload chat history"""

+        # In case user are admin. They can also watch the
+        # public conversations
+        can_see_public: bool = False
+        with Session(engine) as session:
+            statement = select(User).where(User.id == user_id)
+            result = session.exec(statement).one_or_none()
+
+            if result is not None:
+                if flowsettings.KH_USER_CAN_SEE_PUBLIC:
+                    can_see_public = (
+                        result.username == flowsettings.KH_USER_CAN_SEE_PUBLIC
+                    )
+                else:
+                    can_see_public = True
+
+        print(f"User-id: {user_id}, can see public conversations: {can_see_public}")
+
         options = []
         with Session(engine) as session:
-            statement = (
-                select(Conversation)
-                .where(Conversation.user == user_id)
-                .order_by(Conversation.date_created.desc())  # type: ignore
-            )
+            # Define condition based on admin-role:
+            # - can_see: can see their conversations & public files
+            # - can_not_see: only see their conversations
+            if can_see_public:
+                statement = (
+                    select(Conversation)
+                    .where(
+                        or_(
+                            Conversation.user == user_id,
+                            Conversation.is_public,
+                        )
+                    )
+                    .order_by(
+                        Conversation.is_public.desc(), Conversation.date_created.desc()
+                    )  # type: ignore
+                )
+            else:
+                statement = (
+                    select(Conversation)
+                    .where(Conversation.user == user_id)
+                    .order_by(Conversation.date_created.desc())  # type: ignore
+                )

             results = session.exec(statement).all()
             for result in results:
                 options.append((result.name, result.id))
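
The `or_` import earlier feeds the branch above: privileged users get their own conversations unioned with every public one, sorted public-first. Condensed to one function, reusing the models from this diff:

```python
from ktem.db.models import Conversation, engine
from sqlmodel import Session, or_, select

def visible_conversations(user_id):
    # own conversations plus public ones, public listed first
    with Session(engine) as session:
        statement = (
            select(Conversation)
            .where(or_(Conversation.user == user_id, Conversation.is_public))
            .order_by(
                Conversation.is_public.desc(), Conversation.date_created.desc()
            )
        )
        return session.exec(statement).all()
```
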
@@ -129,7 +201,7 @@ class ConversationControl(BasePage):
         else:
             return None, gr.update(value=None, choices=[])

-    def select_conv(self, conversation_id):
+    def select_conv(self, conversation_id, user_id):
         """Select the conversation"""
         with Session(engine) as session:
             statement = select(Conversation).where(Conversation.id == conversation_id)
@@ -137,18 +209,46 @@ class ConversationControl(BasePage):
                 result = session.exec(statement).one()
                 id_ = result.id
                 name = result.name
-                selected = result.data_source.get("selected", {})
+                is_conv_public = result.is_public
+
+                # disable file selection ids state if
+                # not the owner of the conversation
+                if user_id == result.user:
+                    selected = result.data_source.get("selected", {})
+                else:
+                    selected = {}
+
                 chats = result.data_source.get("messages", [])
-                info_panel = ""
+
+                retrieval_history: list[str] = result.data_source.get(
+                    "retrieval_messages", []
+                )
+                plot_history: list[dict] = result.data_source.get("plot_history", [])
+
+                # On initialization
+                # Ensure len of retrieval and messages are equal
+                retrieval_history = sync_retrieval_n_message(chats, retrieval_history)
+
+                info_panel = (
+                    retrieval_history[-1]
+                    if retrieval_history
+                    else "<h5><b>No evidence found.</b></h5>"
+                )
+                plot_data = plot_history[-1] if plot_history else None
                 state = result.data_source.get("state", STATE)

             except Exception as e:
                 logger.warning(e)
                 id_ = ""
                 name = ""
                 selected = {}
                 chats = []
+                retrieval_history = []
+                plot_history = []
                 info_panel = ""
+                plot_data = None
                 state = STATE
+                is_conv_public = False

         indices = []
         for index in self._app.index_manager.indices:
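
`sync_retrieval_n_message` (imported near the top of this file) is called above but its body is not part of this diff. Judging from the call site, its job is to make the retrieval-history list line up one-to-one with the messages of conversations saved before this change. A hypothetical implementation consistent with that usage, not the repo's actual code:

```python
def sync_retrieval_n_message(messages: list, retrieval_messages: list) -> list:
    """Pad or trim retrieval history to match the message count.

    Hypothetical sketch; the real helper lives in ktem's utils module.
    """
    missing = len(messages) - len(retrieval_messages)
    if missing > 0:
        retrieval_messages = retrieval_messages + [""] * missing
    return retrieval_messages[: len(messages)]
```
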
@@ -160,10 +260,29 @@ class ConversationControl(BasePage):
             if isinstance(index.selector, tuple):
                 indices.extend(selected.get(str(index.id), index.default_selector))

-        return id_, id_, name, chats, info_panel, state, *indices
+        return (
+            id_,
+            id_,
+            name,
+            chats,
+            info_panel,
+            plot_data,
+            retrieval_history,
+            plot_history,
+            is_conv_public,
+            state,
+            *indices,
+        )

-    def rename_conv(self, conversation_id, new_name, user_id):
+    def rename_conv(self, conversation_id, new_name, is_renamed, user_id):
         """Rename the conversation"""
+        if not is_renamed:
+            return (
+                gr.update(),
+                conversation_id,
+                gr.update(visible=False),
+            )
+
         if user_id is None:
             gr.Warning("Please sign in first (Settings → User Settings)")
             return gr.update(), ""
@@ -185,7 +304,12 @@ class ConversationControl(BasePage):
             session.commit()

         history = self.load_chat_history(user_id)
-        return gr.update(choices=history), conversation_id
+        gr.Info("Conversation renamed.")
+        return (
+            gr.update(choices=history),
+            conversation_id,
+            gr.update(visible=False),
+        )

     def _on_app_created(self):
         """Reload the conversation once the app is created"""
@@ -12,7 +12,7 @@ class ReportIssue(BasePage):
         self.on_building_ui()

     def on_building_ui(self):
-        with gr.Accordion(label="Report", open=False):
+        with gr.Accordion(label="Feedback", open=False):
             self.correctness = gr.Radio(
                 choices=[
                     ("The answer is correct", "correct"),
@@ -9,6 +9,7 @@ from theflow.settings import settings
 def get_remote_doc(url: str) -> str:
     try:
         res = requests.get(url)
+        res.raise_for_status()
         return res.text
     except Exception as e:
         print(f"Failed to fetch document from {url}: {e}")
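
The one-line fix above matters because `requests.get` does not raise on HTTP error codes by itself: without `raise_for_status()`, a 404 error page would be returned as if it were the requested document. A quick illustration:

```python
import requests

res = requests.get("https://example.com/missing-page")  # illustrative URL
print(res.status_code)  # e.g. 404; res.text would still hold the error page
res.raise_for_status()  # raises requests.exceptions.HTTPError for 4xx/5xx
```
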