(pump:minor) Allow the indexing pipeline to report indexing progress to the UI (#81)

* Turn the file indexing event into a generator to report progress
* Fix the React text-trimming function
* Refactor file deletion into its own method
parent 56dfc8fb53
commit ebf1315569
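The heart of this change is a streaming contract for indexing: the pipeline's `stream()` method is a generator that yields `Document` progress messages on the `index` or `debug` channel while it works, and hands back the final file ids and error messages as the generator's return value. A minimal sketch of that contract, using hypothetical file names and ids rather than anything from this repository:

```python
from typing import Generator, Optional

# Minimal stand-in for kotaemon.base.Document; only the fields this sketch needs.
class Document:
    def __init__(self, content, channel: Optional[str] = None):
        self.content = content
        self.channel = channel  # "index" for UI status updates, "debug" for log lines


def stream(
    file_paths: list[str],
) -> Generator[Document, None, tuple[list[Optional[str]], list[Optional[str]]]]:
    """Yield progress messages while indexing; return (file_ids, errors) at the end."""
    file_ids: list[Optional[str]] = []
    errors: list[Optional[str]] = []
    for idx, path in enumerate(file_paths):
        yield Document(f"Indexing [{idx + 1}/{len(file_paths)}]: {path}", channel="debug")
        try:
            file_id = f"file-{idx}"  # hypothetical id; the real pipeline stores the file first
            file_ids.append(file_id)
            errors.append(None)
            yield Document({"file_path": path, "status": "success"}, channel="index")
        except Exception as e:  # one failed file should not abort the whole batch
            file_ids.append(None)
            errors.append(str(e))
            yield Document(
                {"file_path": path, "status": "failed", "message": str(e)}, channel="index"
            )
    return file_ids, errors
```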
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import AsyncGenerator, Iterator, Optional
+from typing import Any, AsyncGenerator, Iterator, Optional

 from theflow import Function, Node, Param, lazy

@@ -58,7 +58,7 @@ class BaseComponent(Function):
     @abstractmethod
     def run(
         self, *args, **kwargs
-    ) -> Document | list[Document] | Iterator[Document] | None:
+    ) -> Document | list[Document] | Iterator[Document] | None | Any:
         """Run the component."""
         ...

@@ -32,12 +32,13 @@ class Document(BaseDocument):
         channel: the channel to show the document. Optional.:
             - chat: show in chat message
             - info: show in information panel
+            - index: show in index panel
             - debug: show in debug panel
     """

     content: Any = None
     source: Optional[str] = None
-    channel: Optional[Literal["chat", "info", "debug"]] = None
+    channel: Optional[Literal["chat", "info", "index", "debug"]] = None

     def __init__(self, content: Optional[Any] = None, *args, **kwargs):
         if content is None:
@@ -1,6 +1,7 @@
 from pathlib import Path
 from typing import Type

+from llama_index.readers import PDFReader
 from llama_index.readers.base import BaseReader

 from kotaemon.base import BaseComponent, Document, Param
@@ -17,18 +18,20 @@ from kotaemon.loaders import (
     UnstructuredReader,
 )

-KH_DEFAULT_FILE_EXTRACTORS: dict[str, Type[BaseReader]] = {
-    ".xlsx": PandasExcelReader,
-    ".docx": UnstructuredReader,
-    ".xls": UnstructuredReader,
-    ".doc": UnstructuredReader,
-    ".html": HtmlReader,
-    ".mhtml": MhtmlReader,
-    ".png": UnstructuredReader,
-    ".jpeg": UnstructuredReader,
-    ".jpg": UnstructuredReader,
-    ".tiff": UnstructuredReader,
-    ".tif": UnstructuredReader,
+unstructured = UnstructuredReader()
+KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
+    ".xlsx": PandasExcelReader(),
+    ".docx": unstructured,
+    ".xls": unstructured,
+    ".doc": unstructured,
+    ".html": HtmlReader(),
+    ".mhtml": MhtmlReader(),
+    ".png": unstructured,
+    ".jpeg": unstructured,
+    ".jpg": unstructured,
+    ".tiff": unstructured,
+    ".tif": unstructured,
+    ".pdf": PDFReader(),
 }


@@ -64,7 +67,7 @@ class DocumentIngestor(BaseComponent):
     def _get_reader(self, input_files: list[str | Path]):
         """Get appropriate readers for the input files based on file extension"""
         file_extractors: dict[str, BaseReader] = {
-            ext: cls() for ext, cls in KH_DEFAULT_FILE_EXTRACTORS.items()
+            ext: reader for ext, reader in KH_DEFAULT_FILE_EXTRACTORS.items()
         }
         for ext, cls in self.override_file_extractors.items():
             file_extractors[ext] = cls()

@@ -8,6 +8,8 @@ if TYPE_CHECKING:


 class BaseReader(BaseComponent):
+    """The base class for all readers"""
+
     ...


@@ -126,7 +126,7 @@ class BaseIndex(abc.ABC):
         ...

     def get_retriever_pipelines(
-        self, settings: dict, selected: Any = None
+        self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseComponent"]:
         """Return the retriever pipelines to retrieve the entity from the index"""
         return []

@@ -1,10 +1,18 @@
 from pathlib import Path
-from typing import Optional
+from typing import Generator, Optional

-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document, Param


 class BaseFileIndexRetriever(BaseComponent):

+    Source = Param(help="The SQLAlchemy Source table")
+    Index = Param(help="The SQLAlchemy Index table")
+    VS = Param(help="The VectorStore")
+    DS = Param(help="The DocStore")
+    FSPath = Param(help="The file storage path")
+    user_id = Param(help="The user id")
+
     @classmethod
     def get_user_settings(cls) -> dict:
         """Get the user settings for indexing
@@ -24,20 +32,6 @@ class BaseFileIndexRetriever(BaseComponent):
     ) -> "BaseFileIndexRetriever":
         raise NotImplementedError

-    def set_resources(self, resources: dict):
-        """Set the resources for the indexing pipeline
-
-        This will setup the tables, the vector store and docstore.
-
-        Args:
-            resources (dict): the resources for the indexing pipeline
-        """
-        self._Source = resources["Source"]
-        self._Index = resources["Index"]
-        self._VS = resources["VectorStore"]
-        self._DS = resources["DocStore"]
-        self._fs_path = resources["FileStoragePath"]
-

 class BaseFileIndexIndexing(BaseComponent):
     """The pipeline to index information into the data store
@@ -54,11 +48,45 @@ class BaseFileIndexIndexing(BaseComponent):
         - self._DS: the docstore
     """

-    def run(self, file_paths: str | Path | list[str | Path], *args, **kwargs):
+    Source = Param(help="The SQLAlchemy Source table")
+    Index = Param(help="The SQLAlchemy Index table")
+    VS = Param(help="The VectorStore")
+    DS = Param(help="The DocStore")
+    FSPath = Param(help="The file storage path")
+    user_id = Param(help="The user id")
+
+    def run(
+        self, file_paths: str | Path | list[str | Path], *args, **kwargs
+    ) -> tuple[list[str | None], list[str | None]]:
         """Run the indexing pipeline

         Args:
             file_paths (str | Path | list[str | Path]): the file paths to index

+        Returns:
+            - the indexed file ids (each file id corresponds to an input file path, or
+                None if the indexing failed for that file path)
+            - the error messages (each error message corresponds to an input file path,
+                or None if the indexing was successful for that file path)
+        """
+        raise NotImplementedError
+
+    def stream(
+        self, file_paths: str | Path | list[str | Path], *args, **kwargs
+    ) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
+        """Stream the indexing pipeline
+
+        Args:
+            file_paths (str | Path | list[str | Path]): the file paths to index
+
+        Yields:
+            Document: the output message to the UI, must have channel == index or debug
+
+        Returns:
+            - the indexed file ids (each file id corresponds to an input file path, or
+                None if the indexing failed for that file path)
+            - the error messages (each error message corresponds to an input file path,
+                or None if the indexing was successful for that file path)
         """
         raise NotImplementedError

@@ -78,20 +106,6 @@ class BaseFileIndexIndexing(BaseComponent):
         """
         return {}

-    def set_resources(self, resources: dict):
-        """Set the resources for the indexing pipeline
-
-        This will setup the tables, the vector store and docstore.
-
-        Args:
-            resources (dict): the resources for the indexing pipeline
-        """
-        self._Source = resources["Source"]
-        self._Index = resources["Index"]
-        self._VS = resources["VectorStore"]
-        self._DS = resources["DocStore"]
-        self._fs_path = resources["FileStoragePath"]
-
     def copy_to_filestorage(
         self, file_paths: str | Path | list[str | Path]
     ) -> list[str]:
@@ -113,7 +127,7 @@ class BaseFileIndexIndexing(BaseComponent):
         for file_path in file_paths:
             with open(file_path, "rb") as f:
                 paths.append(sha256(f.read()).hexdigest())
-            shutil.copy(file_path, self._fs_path / paths[-1])
+            shutil.copy(file_path, self.FSPath / paths[-1])

         return paths

@@ -362,13 +362,17 @@ class FileIndex(BaseIndex):
                 stripped_settings[key] = value

         obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
-        obj.set_resources(resources=self._resources)
-        obj._user_id = user_id
+        obj.Source = self._resources["Source"]
+        obj.Index = self._resources["Index"]
+        obj.VS = self._vs
+        obj.DS = self._docstore
+        obj.FSPath = self._fs_path
+        obj.user_id = user_id

         return obj

     def get_retriever_pipelines(
-        self, settings: dict, selected: Any = None
+        self, settings: dict, user_id: int, selected: Any = None
     ) -> list["BaseFileIndexRetriever"]:
         # retrieval settings
         prefix = f"index.options.{self.id}."
@@ -387,7 +391,12 @@ class FileIndex(BaseIndex):
             obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
             if obj is None:
                 continue
-            obj.set_resources(self._resources)
+            obj.Source = self._resources["Source"]
+            obj.Index = self._resources["Index"]
+            obj.VS = self._vs
+            obj.DS = self._docstore
+            obj.FSPath = self._fs_path
+            obj.user_id = user_id
             retrievers.append(obj)

         return retrievers

@@ -7,13 +7,13 @@ from collections import defaultdict
 from functools import lru_cache
 from hashlib import sha256
 from pathlib import Path
-from typing import Optional
+from typing import Generator, Optional

-import gradio as gr
-from ktem.components import filestorage_path
 from ktem.db.models import engine
 from ktem.embeddings.manager import embedding_models_manager
 from ktem.llms.manager import llms
+from llama_index.readers.base import BaseReader
+from llama_index.readers.file.base import default_file_metadata_func
 from llama_index.vector_stores import (
     FilterCondition,
     FilterOperator,
@@ -26,10 +26,12 @@ from sqlalchemy.orm import Session
 from theflow.settings import settings
 from theflow.utils.modules import import_dotted_string

-from kotaemon.base import RetrievedDocument
+from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
+from kotaemon.embeddings import BaseEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
-from kotaemon.indices.ingests import DocumentIngestor
+from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
 from kotaemon.indices.rankings import BaseReranking, LLMReranking
+from kotaemon.indices.splitters import BaseSplitter, TokenSplitter

 from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

@@ -43,7 +45,7 @@ def dev_settings():

     if hasattr(settings, "FILE_INDEX_PIPELINE_FILE_EXTRACTORS"):
         file_extractors = {
-            key: import_dotted_string(value, safe=False)
+            key: import_dotted_string(value, safe=False)()
             for key, value in settings.FILE_INDEX_PIPELINE_FILE_EXTRACTORS.items()
         }

@ -72,12 +74,20 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
mmr: whether to use mmr to re-rank the documents
|
mmr: whether to use mmr to re-rank the documents
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
|
embedding: BaseEmbeddings
|
||||||
reranker: BaseReranking = LLMReranking.withx()
|
reranker: BaseReranking = LLMReranking.withx()
|
||||||
get_extra_table: bool = False
|
get_extra_table: bool = False
|
||||||
mmr: bool = False
|
mmr: bool = False
|
||||||
top_k: int = 5
|
top_k: int = 5
|
||||||
|
|
||||||
|
@Node.auto(depends_on=["embedding", "VS", "DS"])
|
||||||
|
def vector_retrieval(self) -> VectorRetrieval:
|
||||||
|
return VectorRetrieval(
|
||||||
|
embedding=self.embedding,
|
||||||
|
vector_store=self.VS,
|
||||||
|
doc_store=self.DS,
|
||||||
|
)
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
|
@ -95,13 +105,11 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
logger.info(f"Skip retrieval because of no selected files: {self}")
|
logger.info(f"Skip retrieval because of no selected files: {self}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
Index = self._Index
|
|
||||||
|
|
||||||
retrieval_kwargs = {}
|
retrieval_kwargs = {}
|
||||||
with Session(engine) as session:
|
with Session(engine) as session:
|
||||||
stmt = select(Index).where(
|
stmt = select(self.Index).where(
|
||||||
Index.relation_type == "vector",
|
self.Index.relation_type == "vector",
|
||||||
Index.source_id.in_(doc_ids), # type: ignore
|
self.Index.source_id.in_(doc_ids),
|
||||||
)
|
)
|
||||||
results = session.execute(stmt)
|
results = session.execute(stmt)
|
||||||
vs_ids = [r[0].target_id for r in results.all()]
|
vs_ids = [r[0].target_id for r in results.all()]
|
||||||
|
@ -186,7 +194,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
"component": "dropdown",
|
"component": "dropdown",
|
||||||
},
|
},
|
||||||
"num_retrieval": {
|
"num_retrieval": {
|
||||||
"name": "Number of documents to retrieve",
|
"name": "Number of document chunks to retrieve",
|
||||||
"value": 3,
|
"value": 3,
|
||||||
"component": "number",
|
"component": "number",
|
||||||
},
|
},
|
||||||
|
@ -228,6 +236,11 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
get_extra_table=user_settings["prioritize_table"],
|
get_extra_table=user_settings["prioritize_table"],
|
||||||
top_k=user_settings["num_retrieval"],
|
top_k=user_settings["num_retrieval"],
|
||||||
mmr=user_settings["mmr"],
|
mmr=user_settings["mmr"],
|
||||||
|
embedding=embedding_models_manager[
|
||||||
|
index_settings.get(
|
||||||
|
"embedding", embedding_models_manager.get_default_name()
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
if not user_settings["use_reranking"]:
|
if not user_settings["use_reranking"]:
|
||||||
retriever.reranker = None # type: ignore
|
retriever.reranker = None # type: ignore
|
||||||
|
@ -236,226 +249,346 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
|
||||||
user_settings["reranking_llm"], llms.get_default()
|
user_settings["reranking_llm"], llms.get_default()
|
||||||
)
|
)
|
||||||
|
|
||||||
retriever.vector_retrieval.embedding = embedding_models_manager[
|
|
||||||
index_settings.get("embedding", embedding_models_manager.get_default_name())
|
|
||||||
]
|
|
||||||
kwargs = {".doc_ids": selected}
|
kwargs = {".doc_ids": selected}
|
||||||
retriever.set_run(kwargs, temp=True)
|
retriever.set_run(kwargs, temp=True)
|
||||||
return retriever
|
return retriever
|
||||||
|
|
||||||
def set_resources(self, resources: dict):
|
|
||||||
super().set_resources(resources)
|
class IndexPipeline(BaseComponent):
|
||||||
self.vector_retrieval.vector_store = self._VS
|
"""Index a single file"""
|
||||||
self.vector_retrieval.doc_store = self._DS
|
|
||||||
|
loader: BaseReader
|
||||||
|
splitter: BaseSplitter
|
||||||
|
chunk_batch_size: int = 50
|
||||||
|
|
||||||
|
Source = Param(help="The SQLAlchemy Source table")
|
||||||
|
Index = Param(help="The SQLAlchemy Index table")
|
||||||
|
VS = Param(help="The VectorStore")
|
||||||
|
DS = Param(help="The DocStore")
|
||||||
|
FSPath = Param(help="The file storage path")
|
||||||
|
user_id = Param(help="The user id")
|
||||||
|
embedding: BaseEmbeddings
|
||||||
|
|
||||||
|
@Node.auto(depends_on=["Source", "Index", "embedding"])
|
||||||
|
def vector_indexing(self) -> VectorIndexing:
|
||||||
|
return VectorIndexing(
|
||||||
|
vector_store=self.VS, doc_store=self.DS, embedding=self.embedding
|
||||||
|
)
|
||||||
|
|
||||||
|
def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
|
||||||
|
chunks = []
|
||||||
|
n_chunks = 0
|
||||||
|
for cidx, chunk in enumerate(self.splitter(docs)):
|
||||||
|
chunks.append(chunk)
|
||||||
|
if cidx % self.chunk_batch_size == 0:
|
||||||
|
self.handle_chunks(chunks, file_id)
|
||||||
|
n_chunks += len(chunks)
|
||||||
|
chunks = []
|
||||||
|
yield Document(
|
||||||
|
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
if chunks:
|
||||||
|
self.handle_chunks(chunks, file_id)
|
||||||
|
n_chunks += len(chunks)
|
||||||
|
yield Document(
|
||||||
|
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
|
||||||
|
)
|
||||||
|
|
||||||
|
return n_chunks
|
||||||
|
|
||||||
|
def handle_chunks(self, chunks, file_id):
|
||||||
|
"""Run chunks"""
|
||||||
|
# run embedding, add to both vector store and doc store
|
||||||
|
self.vector_indexing(chunks)
|
||||||
|
|
||||||
|
# record in the index
|
||||||
|
with Session(engine) as session:
|
||||||
|
nodes = []
|
||||||
|
for chunk in chunks:
|
||||||
|
nodes.append(
|
||||||
|
self.Index(
|
||||||
|
source_id=file_id,
|
||||||
|
target_id=chunk.doc_id,
|
||||||
|
relation_type="document",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
nodes.append(
|
||||||
|
self.Index(
|
||||||
|
source_id=file_id,
|
||||||
|
target_id=chunk.doc_id,
|
||||||
|
relation_type="vector",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.add_all(nodes)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def get_id_if_exists(self, file_path: Path) -> Optional[str]:
|
||||||
|
"""Check if the file is already indexed
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: the path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the file id if the file is indexed, otherwise None
|
||||||
|
"""
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = select(self.Source).where(self.Source.name == file_path.name)
|
||||||
|
item = session.execute(stmt).first()
|
||||||
|
if item:
|
||||||
|
return item[0].id
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def store_file(self, file_path: Path) -> str:
|
||||||
|
"""Store file into the database and storage, return the file id
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: the path to the file
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
the file id
|
||||||
|
"""
|
||||||
|
with file_path.open("rb") as fi:
|
||||||
|
file_hash = sha256(fi.read()).hexdigest()
|
||||||
|
|
||||||
|
shutil.copy(file_path, self.FSPath / file_hash)
|
||||||
|
source = self.Source(
|
||||||
|
name=file_path.name,
|
||||||
|
path=file_hash,
|
||||||
|
size=file_path.stat().st_size,
|
||||||
|
user=self.user_id, # type: ignore
|
||||||
|
)
|
||||||
|
with Session(engine) as session:
|
||||||
|
session.add(source)
|
||||||
|
session.commit()
|
||||||
|
file_id = source.id
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
def finish(self, file_id: str, file_path: Path) -> str:
|
||||||
|
"""Finish the indexing"""
|
||||||
|
with Session(engine) as session:
|
||||||
|
stmt = select(self.Index.target_id).where(self.Index.source_id == file_id)
|
||||||
|
doc_ids = [_[0] for _ in session.execute(stmt)]
|
||||||
|
if doc_ids:
|
||||||
|
docs = self.DS.get(doc_ids)
|
||||||
|
stmt = select(self.Source).where(self.Source.id == file_id)
|
||||||
|
result = session.execute(stmt).first()
|
||||||
|
if result:
|
||||||
|
item = result[0]
|
||||||
|
item.text_length = sum([len(doc.text) for doc in docs])
|
||||||
|
session.add(item)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
def delete_file(self, file_id: str):
|
||||||
|
"""Delete a file from the db, including its chunks in docstore and vectorstore
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_id: the file id
|
||||||
|
"""
|
||||||
|
with Session(engine) as session:
|
||||||
|
session.execute(delete(self.Source).where(self.Source.id == file_id))
|
||||||
|
vs_ids, ds_ids = [], []
|
||||||
|
index = session.execute(
|
||||||
|
select(self.Index).where(self.Index.source_id == file_id)
|
||||||
|
).all()
|
||||||
|
for each in index:
|
||||||
|
if each[0].relation_type == "vector":
|
||||||
|
vs_ids.append(each[0].target_id)
|
||||||
|
else:
|
||||||
|
ds_ids.append(each[0].target_id)
|
||||||
|
session.delete(each[0])
|
||||||
|
session.commit()
|
||||||
|
self.VS.delete(vs_ids)
|
||||||
|
self.DS.delete(ds_ids)
|
||||||
|
|
||||||
|
def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
|
||||||
|
"""Index the file and return the file id"""
|
||||||
|
# check for duplication
|
||||||
|
file_path = Path(file_path).resolve()
|
||||||
|
file_id = self.get_id_if_exists(file_path)
|
||||||
|
if file_id is not None:
|
||||||
|
if not reindex:
|
||||||
|
raise ValueError(
|
||||||
|
f"File {file_path.name} already indexed. Please rerun with "
|
||||||
|
"reindex=True to force reindexing."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# remove the existing records
|
||||||
|
self.delete_file(file_id)
|
||||||
|
file_id = self.store_file(file_path)
|
||||||
|
else:
|
||||||
|
# add record to db
|
||||||
|
file_id = self.store_file(file_path)
|
||||||
|
|
||||||
|
# extract the file
|
||||||
|
extra_info = default_file_metadata_func(str(file_path))
|
||||||
|
docs = self.loader.load_data(file_path, extra_info=extra_info)
|
||||||
|
for _ in self.handle_docs(docs, file_id, file_path.name):
|
||||||
|
continue
|
||||||
|
self.finish(file_id, file_path)
|
||||||
|
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
def stream(
|
||||||
|
self, file_path: str | Path, reindex: bool, **kwargs
|
||||||
|
) -> Generator[Document, None, str]:
|
||||||
|
# check for duplication
|
||||||
|
file_path = Path(file_path).resolve()
|
||||||
|
file_id = self.get_id_if_exists(file_path)
|
||||||
|
if file_id is not None:
|
||||||
|
if not reindex:
|
||||||
|
raise ValueError(
|
||||||
|
f"File {file_path.name} already indexed. Please rerun with "
|
||||||
|
"reindex=True to force reindexing."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# remove the existing records
|
||||||
|
yield Document(f" => Removing old {file_path.name}", channel="debug")
|
||||||
|
self.delete_file(file_id)
|
||||||
|
file_id = self.store_file(file_path)
|
||||||
|
else:
|
||||||
|
# add record to db
|
||||||
|
file_id = self.store_file(file_path)
|
||||||
|
|
||||||
|
# extract the file
|
||||||
|
extra_info = default_file_metadata_func(str(file_path))
|
||||||
|
yield Document(f" => Converting {file_path.name} to text", channel="debug")
|
||||||
|
docs = self.loader.load_data(file_path, extra_info=extra_info)
|
||||||
|
yield Document(f" => Converted {file_path.name} to text", channel="debug")
|
||||||
|
yield from self.handle_docs(docs, file_id, file_path.name)
|
||||||
|
|
||||||
|
self.finish(file_id, file_path)
|
||||||
|
|
||||||
|
yield Document(f" => Finished indexing {file_path.name}", channel="debug")
|
||||||
|
return file_id
|
||||||
|
|
||||||
|
|
||||||
class IndexDocumentPipeline(BaseFileIndexIndexing):
|
class IndexDocumentPipeline(BaseFileIndexIndexing):
|
||||||
"""Store the documents and index the content into vector store and doc store
|
"""Index the file. Decide which pipeline based on the file type.
|
||||||
|
|
||||||
Args:
|
This method is essentially a factory to decide which indexing pipeline to use.
|
||||||
indexing_vector_pipeline: pipeline to index the documents
|
|
||||||
file_ingestor: ingestor to ingest the documents
|
We can decide the pipeline programmatically, and/or automatically based on an LLM.
|
||||||
|
If we based on the LLM, essentially we will log the LLM thought process in a file,
|
||||||
|
and then during the indexing, we will read that file to decide which pipeline
|
||||||
|
to use, and then log the operation in that file. Overtime, the LLM can learn to
|
||||||
|
decide which pipeline should be used.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
|
embedding: BaseEmbeddings
|
||||||
file_ingestor: DocumentIngestor = DocumentIngestor.withx()
|
|
||||||
|
@classmethod
|
||||||
|
def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
|
||||||
|
obj = cls(
|
||||||
|
embedding=embedding_models_manager[
|
||||||
|
index_settings.get(
|
||||||
|
"embedding", embedding_models_manager.get_default_name()
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
def route(self, file_path: Path) -> IndexPipeline:
|
||||||
|
"""Decide the pipeline based on the file type
|
||||||
|
|
||||||
|
Can subclass this method for a more elaborate pipeline routing strategy.
|
||||||
|
"""
|
||||||
|
readers, chunk_size, chunk_overlap = dev_settings()
|
||||||
|
|
||||||
|
ext = file_path.suffix
|
||||||
|
reader = readers.get(ext, KH_DEFAULT_FILE_EXTRACTORS.get(ext, None))
|
||||||
|
if reader is None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"No supported pipeline to index {file_path.name}. Please specify "
|
||||||
|
"the suitable pipeline for this file type in the settings."
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline: IndexPipeline = IndexPipeline(
|
||||||
|
loader=reader,
|
||||||
|
splitter=TokenSplitter(
|
||||||
|
chunk_size=chunk_size or 1024,
|
||||||
|
chunk_overlap=chunk_overlap or 256,
|
||||||
|
separator="\n\n",
|
||||||
|
backup_separators=["\n", ".", "\u200B"],
|
||||||
|
),
|
||||||
|
Source=self.Source,
|
||||||
|
Index=self.Index,
|
||||||
|
VS=self.VS,
|
||||||
|
DS=self.DS,
|
||||||
|
FSPath=self.FSPath,
|
||||||
|
user_id=self.user_id,
|
||||||
|
embedding=self.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
def run(
|
def run(
|
||||||
self,
|
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
|
||||||
file_paths: str | Path | list[str | Path],
|
) -> tuple[list[str | None], list[str | None]]:
|
||||||
reindex: bool = False,
|
"""Return a list of indexed file ids, and a list of errors"""
|
||||||
**kwargs, # type: ignore
|
|
||||||
):
|
|
||||||
"""Index the list of documents
|
|
||||||
|
|
||||||
This function will extract the files, persist the files to storage,
|
|
||||||
index the files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_paths: list of file paths to index
|
|
||||||
reindex: whether to force reindexing the files if they exist
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list of split nodes
|
|
||||||
"""
|
|
||||||
Source = self._Source
|
|
||||||
Index = self._Index
|
|
||||||
|
|
||||||
if not isinstance(file_paths, list):
|
if not isinstance(file_paths, list):
|
||||||
file_paths = [file_paths]
|
file_paths = [file_paths]
|
||||||
|
|
||||||
to_index: list[str] = []
|
file_ids: list[str | None] = []
|
||||||
file_to_hash: dict[str, str] = {}
|
errors: list[str | None] = []
|
||||||
errors = []
|
|
||||||
to_update = []
|
|
||||||
|
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
abs_path = str(Path(file_path).resolve())
|
file_path = Path(file_path)
|
||||||
with open(abs_path, "rb") as fi:
|
|
||||||
file_hash = sha256(fi.read()).hexdigest()
|
|
||||||
|
|
||||||
file_to_hash[abs_path] = file_hash
|
try:
|
||||||
|
pipeline = self.route(file_path)
|
||||||
|
file_id = pipeline.run(file_path, reindex=reindex, **kwargs)
|
||||||
|
file_ids.append(file_id)
|
||||||
|
errors.append(None)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(e)
|
||||||
|
file_ids.append(None)
|
||||||
|
errors.append(str(e))
|
||||||
|
|
||||||
with Session(engine) as session:
|
return file_ids, errors
|
||||||
statement = select(Source).where(Source.name == Path(abs_path).name)
|
|
||||||
item = session.execute(statement).first()
|
|
||||||
|
|
||||||
if item:
|
def stream(
|
||||||
if not reindex:
|
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
|
||||||
errors.append(Path(abs_path).name)
|
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
|
||||||
continue
|
"""Return a list of indexed file ids, and a list of errors"""
|
||||||
else:
|
if not isinstance(file_paths, list):
|
||||||
to_update.append(Path(abs_path).name)
|
file_paths = [file_paths]
|
||||||
|
|
||||||
to_index.append(abs_path)
|
file_ids: list[str | None] = []
|
||||||
|
errors: list[str | None] = []
|
||||||
if errors:
|
n_files = len(file_paths)
|
||||||
error_files = ", ".join(errors)
|
for idx, file_path in enumerate(file_paths):
|
||||||
if len(error_files) > 100:
|
file_path = Path(file_path)
|
||||||
error_files = error_files[:80] + "..."
|
yield Document(
|
||||||
print(
|
content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
|
||||||
"Skip these files already exist. Please rename/remove them or "
|
channel="debug",
|
||||||
f"enable reindex:\n{errors}"
|
|
||||||
)
|
|
||||||
self.warning(
|
|
||||||
"Skip these files already exist. Please rename/remove them or "
|
|
||||||
f"enable reindex:\n{error_files}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if not to_index:
|
try:
|
||||||
return [], []
|
pipeline = self.route(file_path)
|
||||||
|
file_id = yield from pipeline.stream(
|
||||||
# persist the files to storage
|
file_path, reindex=reindex, **kwargs
|
||||||
for path in to_index:
|
|
||||||
shutil.copy(path, filestorage_path / file_to_hash[path])
|
|
||||||
|
|
||||||
# extract the file & prepare record info
|
|
||||||
file_to_source: dict = {}
|
|
||||||
extraction_errors = []
|
|
||||||
nodes = []
|
|
||||||
for file_path, file_hash in file_to_hash.items():
|
|
||||||
if str(Path(file_path).resolve()) not in to_index:
|
|
||||||
continue
|
|
||||||
|
|
||||||
extraction_result = self.file_ingestor(file_path)
|
|
||||||
if not extraction_result:
|
|
||||||
extraction_errors.append(Path(file_path).name)
|
|
||||||
continue
|
|
||||||
nodes.extend(extraction_result)
|
|
||||||
source = Source(
|
|
||||||
name=Path(file_path).name,
|
|
||||||
path=file_hash,
|
|
||||||
size=Path(file_path).stat().st_size,
|
|
||||||
user=self._user_id, # type: ignore
|
|
||||||
)
|
)
|
||||||
file_to_source[file_path] = source
|
file_ids.append(file_id)
|
||||||
|
errors.append(None)
|
||||||
if extraction_errors:
|
yield Document(
|
||||||
msg = "Failed to extract these files: {}".format(
|
content={"file_path": file_path, "status": "success"},
|
||||||
", ".join(extraction_errors)
|
channel="index",
|
||||||
)
|
)
|
||||||
print(msg)
|
except Exception as e:
|
||||||
self.warning(msg)
|
logger.error(e)
|
||||||
|
file_ids.append(None)
|
||||||
if not nodes:
|
errors.append(str(e))
|
||||||
return [], []
|
yield Document(
|
||||||
|
content={
|
||||||
print(
|
"file_path": file_path,
|
||||||
"Extracted",
|
"status": "failed",
|
||||||
len(to_index) - len(extraction_errors),
|
"message": str(e),
|
||||||
"files into",
|
|
||||||
len(nodes),
|
|
||||||
"nodes",
|
|
||||||
)
|
|
||||||
|
|
||||||
# index the files
|
|
||||||
print("Indexing the files into vector store")
|
|
||||||
self.indexing_vector_pipeline(nodes)
|
|
||||||
print("Finishing indexing the files into vector store")
|
|
||||||
|
|
||||||
# persist to the index
|
|
||||||
print("Persisting the vector and the document into index")
|
|
||||||
file_ids = []
|
|
||||||
to_update = list(set(to_update))
|
|
||||||
with Session(engine) as session:
|
|
||||||
if to_update:
|
|
||||||
session.execute(delete(Source).where(Source.name.in_(to_update)))
|
|
||||||
|
|
||||||
for source in file_to_source.values():
|
|
||||||
session.add(source)
|
|
||||||
session.commit()
|
|
||||||
for source in file_to_source.values():
|
|
||||||
file_ids.append(source.id)
|
|
||||||
|
|
||||||
for node in nodes:
|
|
||||||
file_path = str(node.metadata["file_path"])
|
|
||||||
node.source = str(file_to_source[file_path].id)
|
|
||||||
file_to_source[file_path].text_length += len(node.text)
|
|
||||||
|
|
||||||
session.flush()
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
with Session(engine) as session:
|
|
||||||
for node in nodes:
|
|
||||||
index = Index(
|
|
||||||
source_id=node.source,
|
|
||||||
target_id=node.doc_id,
|
|
||||||
relation_type="document",
|
|
||||||
)
|
|
||||||
session.add(index)
|
|
||||||
for node in nodes:
|
|
||||||
index = Index(
|
|
||||||
source_id=node.source,
|
|
||||||
target_id=node.doc_id,
|
|
||||||
relation_type="vector",
|
|
||||||
)
|
|
||||||
session.add(index)
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
print("Finishing persisting the vector and the document into index")
|
|
||||||
print(f"{len(nodes)} nodes are indexed")
|
|
||||||
return nodes, file_ids
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_user_settings(cls) -> dict:
|
|
||||||
return {
|
|
||||||
"index_parser": {
|
|
||||||
"name": "Index parser",
|
|
||||||
"value": "normal",
|
|
||||||
"choices": [
|
|
||||||
("PDF text parser", "normal"),
|
|
||||||
("Mathpix", "mathpix"),
|
|
||||||
("Advanced ocr", "ocr"),
|
|
||||||
("Multimodal parser", "multimodal"),
|
|
||||||
],
|
|
||||||
"component": "dropdown",
|
|
||||||
},
|
},
|
||||||
}
|
channel="index",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
return file_ids, errors
|
||||||
def get_pipeline(cls, user_settings, index_settings) -> "IndexDocumentPipeline":
|
|
||||||
"""Get the pipeline based on the setting"""
|
|
||||||
obj = cls()
|
|
||||||
obj.file_ingestor.pdf_mode = user_settings["index_parser"]
|
|
||||||
|
|
||||||
file_extractors, chunk_size, chunk_overlap = dev_settings()
|
|
||||||
if file_extractors:
|
|
||||||
obj.file_ingestor.override_file_extractors = file_extractors
|
|
||||||
if chunk_size:
|
|
||||||
obj.file_ingestor.text_splitter.chunk_size = chunk_size
|
|
||||||
if chunk_overlap:
|
|
||||||
obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap
|
|
||||||
|
|
||||||
obj.indexing_vector_pipeline.embedding = embedding_models_manager[
|
|
||||||
index_settings.get("embedding", embedding_models_manager.get_default_name())
|
|
||||||
]
|
|
||||||
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def set_resources(self, resources: dict):
|
|
||||||
super().set_resources(resources)
|
|
||||||
self.indexing_vector_pipeline.vector_store = self._VS
|
|
||||||
self.indexing_vector_pipeline.doc_store = self._DS
|
|
||||||
|
|
||||||
def warning(self, msg):
|
|
||||||
gr.Warning(msg)
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
|
@ -63,9 +64,6 @@ class DirectoryUpload(BasePage):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.upload_button = gr.Button("Upload and Index")
|
self.upload_button = gr.Button("Upload and Index")
|
||||||
self.file_output = gr.File(
|
|
||||||
visible=False, label="Output files (debug purpose)"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FileIndexPage(BasePage):
|
class FileIndexPage(BasePage):
|
||||||
|
@ -127,11 +125,23 @@ class FileIndexPage(BasePage):
|
||||||
self.upload_button = gr.Button(
|
self.upload_button = gr.Button(
|
||||||
"Upload and Index", variant="primary"
|
"Upload and Index", variant="primary"
|
||||||
)
|
)
|
||||||
self.file_output = gr.File(
|
|
||||||
visible=False, label="Output files (debug purpose)"
|
|
||||||
)
|
|
||||||
|
|
||||||
with gr.Column(scale=4):
|
with gr.Column(scale=4):
|
||||||
|
with gr.Column(visible=False) as self.upload_progress_panel:
|
||||||
|
gr.Markdown("## Upload Progress")
|
||||||
|
with gr.Row():
|
||||||
|
self.upload_result = gr.Textbox(
|
||||||
|
lines=1, max_lines=20, label="Upload result"
|
||||||
|
)
|
||||||
|
self.upload_info = gr.Textbox(
|
||||||
|
lines=1, max_lines=20, label="Upload info"
|
||||||
|
)
|
||||||
|
self.btn_close_upload_progress_panel = gr.Button(
|
||||||
|
"Clear Upload Info and Close",
|
||||||
|
variant="secondary",
|
||||||
|
elem_classes=["right-button"],
|
||||||
|
)
|
||||||
|
|
||||||
gr.Markdown("## File List")
|
gr.Markdown("## File List")
|
||||||
self.file_list_state = gr.State(value=None)
|
self.file_list_state = gr.State(value=None)
|
||||||
self.file_list = gr.DataFrame(
|
self.file_list = gr.DataFrame(
|
||||||
|
@ -261,6 +271,9 @@ class FileIndexPage(BasePage):
|
||||||
)
|
)
|
||||||
|
|
||||||
onUploaded = self.upload_button.click(
|
onUploaded = self.upload_button.click(
|
||||||
|
fn=lambda: gr.update(visible=True),
|
||||||
|
outputs=[self.upload_progress_panel],
|
||||||
|
).then(
|
||||||
fn=self.index_fn,
|
fn=self.index_fn,
|
||||||
inputs=[
|
inputs=[
|
||||||
self.files,
|
self.files,
|
||||||
|
@ -268,16 +281,28 @@ class FileIndexPage(BasePage):
|
||||||
self._app.settings_state,
|
self._app.settings_state,
|
||||||
self._app.user_id,
|
self._app.user_id,
|
||||||
],
|
],
|
||||||
outputs=[self.file_output],
|
outputs=[self.upload_result, self.upload_info],
|
||||||
concurrency_limit=20,
|
concurrency_limit=20,
|
||||||
).then(
|
)
|
||||||
|
|
||||||
|
uploadedEvent = onUploaded.then(
|
||||||
fn=self.list_file,
|
fn=self.list_file,
|
||||||
inputs=[self._app.user_id],
|
inputs=[self._app.user_id],
|
||||||
outputs=[self.file_list_state, self.file_list],
|
outputs=[self.file_list_state, self.file_list],
|
||||||
concurrency_limit=20,
|
concurrency_limit=20,
|
||||||
)
|
)
|
||||||
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
|
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
|
||||||
onUploaded = onUploaded.then(**event)
|
uploadedEvent = uploadedEvent.then(**event)
|
||||||
|
|
||||||
|
_ = onUploaded.success(
|
||||||
|
fn=lambda: None,
|
||||||
|
outputs=[self.files],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.btn_close_upload_progress_panel.click(
|
||||||
|
fn=lambda: (gr.update(visible=False), "", ""),
|
||||||
|
outputs=[self.upload_progress_panel, self.upload_result, self.upload_info],
|
||||||
|
)
|
||||||
|
|
||||||
self.file_list.select(
|
self.file_list.select(
|
||||||
fn=self.interact_file_list,
|
fn=self.interact_file_list,
|
||||||
|
@ -294,7 +319,9 @@ class FileIndexPage(BasePage):
|
||||||
outputs=[self.file_list_state, self.file_list],
|
outputs=[self.file_list_state, self.file_list],
|
||||||
)
|
)
|
||||||
|
|
||||||
def index_fn(self, files, reindex: bool, settings, user_id):
|
def index_fn(
|
||||||
|
self, files, reindex: bool, settings, user_id
|
||||||
|
) -> Generator[tuple[str, str], None, None]:
|
||||||
"""Upload and index the files
|
"""Upload and index the files
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -305,35 +332,56 @@ class FileIndexPage(BasePage):
|
||||||
"""
|
"""
|
||||||
if not files:
|
if not files:
|
||||||
gr.Info("No uploaded file")
|
gr.Info("No uploaded file")
|
||||||
return gr.update()
|
yield "", ""
|
||||||
|
return
|
||||||
|
|
||||||
errors = self.validate(files)
|
errors = self.validate(files)
|
||||||
if errors:
|
if errors:
|
||||||
gr.Warning(", ".join(errors))
|
gr.Warning(", ".join(errors))
|
||||||
return gr.update()
|
yield "", ""
|
||||||
|
return
|
||||||
|
|
||||||
gr.Info(f"Start indexing {len(files)} files...")
|
gr.Info(f"Start indexing {len(files)} files...")
|
||||||
|
|
||||||
# get the pipeline
|
# get the pipeline
|
||||||
indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)
|
indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)
|
||||||
|
|
||||||
result = indexing_pipeline(files, reindex=reindex)
|
outputs, debugs = [], []
|
||||||
if result is None:
|
# stream the output
|
||||||
gr.Info("Finish indexing")
|
output_stream = indexing_pipeline.stream(files, reindex=reindex)
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
response = next(output_stream)
|
||||||
|
if response is None:
|
||||||
|
continue
|
||||||
|
if response.channel == "index":
|
||||||
|
if response.content["status"] == "success":
|
||||||
|
outputs.append(f"\u2705 | {response.content['file_path'].name}")
|
||||||
|
elif response.content["status"] == "failed":
|
||||||
|
outputs.append(
|
||||||
|
f"\u274c | {response.content['file_path'].name}: "
|
||||||
|
f"{response.content['message']}"
|
||||||
|
)
|
||||||
|
elif response.channel == "debug":
|
||||||
|
debugs.append(response.text)
|
||||||
|
yield "\n".join(outputs), "\n".join(debugs)
|
||||||
|
except StopIteration as e:
|
||||||
|
result, errors = e.value
|
||||||
|
except Exception as e:
|
||||||
|
debugs.append(f"Error: {e}")
|
||||||
|
yield "\n".join(outputs), "\n".join(debugs)
|
||||||
return
|
return
|
||||||
output_nodes, _ = result
|
|
||||||
gr.Info(f"Finish indexing into {len(output_nodes)} chunks")
|
|
||||||
|
|
||||||
# download the file
|
n_successes = len([_ for _ in result if _])
|
||||||
text = "\n\n".join([each.text for each in output_nodes])
|
if n_successes:
|
||||||
handler, file_path = tempfile.mkstemp(suffix=".txt")
|
gr.Info(f"Successfully index {n_successes} files")
|
||||||
with open(file_path, "w", encoding="utf-8") as f:
|
n_errors = len([_ for _ in errors if _])
|
||||||
f.write(text)
|
if n_errors:
|
||||||
os.close(handler)
|
gr.Warning(f"Have errors for {n_errors} files")
|
||||||
|
|
||||||
return gr.update(value=file_path, visible=True)
|
def index_files_from_dir(
|
||||||
|
self, folder_path, reindex, settings, user_id
|
||||||
def index_files_from_dir(self, folder_path, reindex, settings, user_id):
|
) -> Generator[tuple[str, str], None, None]:
|
||||||
"""This should be constructable by users
|
"""This should be constructable by users
|
||||||
|
|
||||||
It means that the users can build their own index.
|
It means that the users can build their own index.
|
||||||
|
@ -363,6 +411,7 @@ class FileIndexPage(BasePage):
|
||||||
2. Implement the transformation from artifacts to UI
|
2. Implement the transformation from artifacts to UI
|
||||||
"""
|
"""
|
||||||
if not folder_path:
|
if not folder_path:
|
||||||
|
yield "", ""
|
||||||
return
|
return
|
||||||
|
|
||||||
import fnmatch
|
import fnmatch
|
||||||
|
@ -401,7 +450,7 @@ class FileIndexPage(BasePage):
|
||||||
for p in exclude_patterns:
|
for p in exclude_patterns:
|
||||||
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
|
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]
|
||||||
|
|
||||||
return self.index_fn(files, reindex, settings, user_id)
|
yield from self.index_fn(files, reindex, settings, user_id)
|
||||||
|
|
||||||
def list_file(self, user_id):
|
def list_file(self, user_id):
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
|
|
|
@ -99,6 +99,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_panel.chatbot,
|
self.chat_panel.chatbot,
|
||||||
self._app.settings_state,
|
self._app.settings_state,
|
||||||
self.chat_state,
|
self.chat_state,
|
||||||
|
self._app.user_id,
|
||||||
]
|
]
|
||||||
+ self._indices_input,
|
+ self._indices_input,
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -127,6 +128,7 @@ class ChatPage(BasePage):
|
||||||
self.chat_panel.chatbot,
|
self.chat_panel.chatbot,
|
||||||
self._app.settings_state,
|
self._app.settings_state,
|
||||||
self.chat_state,
|
self.chat_state,
|
||||||
|
self._app.user_id,
|
||||||
]
|
]
|
||||||
+ self._indices_input,
|
+ self._indices_input,
|
||||||
outputs=[
|
outputs=[
|
||||||
|
@ -360,7 +362,7 @@ class ChatPage(BasePage):
|
||||||
session.add(result)
|
session.add(result)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
def create_pipeline(self, settings: dict, state: dict, *selecteds):
|
def create_pipeline(self, settings: dict, state: dict, user_id: int, *selecteds):
|
||||||
"""Create the pipeline from settings
|
"""Create the pipeline from settings
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
@ -385,7 +387,9 @@ class ChatPage(BasePage):
|
||||||
if isinstance(index.selector, tuple):
|
if isinstance(index.selector, tuple):
|
||||||
for i in index.selector:
|
for i in index.selector:
|
||||||
index_selected.append(selecteds[i])
|
index_selected.append(selecteds[i])
|
||||||
iretrievers = index.get_retriever_pipelines(settings, index_selected)
|
iretrievers = index.get_retriever_pipelines(
|
||||||
|
settings, user_id, index_selected
|
||||||
|
)
|
||||||
retrievers += iretrievers
|
retrievers += iretrievers
|
||||||
|
|
||||||
# prepare states
|
# prepare states
|
||||||
|
@ -398,7 +402,9 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
return pipeline, reasoning_state
|
return pipeline, reasoning_state
|
||||||
|
|
||||||
def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
|
def chat_fn(
|
||||||
|
self, conversation_id, chat_history, settings, state, user_id, *selecteds
|
||||||
|
):
|
||||||
"""Chat function"""
|
"""Chat function"""
|
||||||
chat_input = chat_history[-1][0]
|
chat_input = chat_history[-1][0]
|
||||||
chat_history = chat_history[:-1]
|
chat_history = chat_history[:-1]
|
||||||
|
@ -406,7 +412,9 @@ class ChatPage(BasePage):
|
||||||
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
|
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()
|
||||||
|
|
||||||
# construct the pipeline
|
# construct the pipeline
|
||||||
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
|
pipeline, reasoning_state = self.create_pipeline(
|
||||||
|
settings, state, user_id, *selecteds
|
||||||
|
)
|
||||||
pipeline.set_output_queue(queue)
|
pipeline.set_output_queue(queue)
|
||||||
|
|
||||||
text, refs = "", ""
|
text, refs = "", ""
|
||||||
|
@ -452,7 +460,9 @@ class ChatPage(BasePage):
|
||||||
print(f"Generate nothing: {empty_msg}")
|
print(f"Generate nothing: {empty_msg}")
|
||||||
yield chat_history + [(chat_input, text or empty_msg)], refs, state
|
yield chat_history + [(chat_input, text or empty_msg)], refs, state
|
||||||
|
|
||||||
def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
|
def regen_fn(
|
||||||
|
self, conversation_id, chat_history, settings, state, user_id, *selecteds
|
||||||
|
):
|
||||||
"""Regen function"""
|
"""Regen function"""
|
||||||
if not chat_history:
|
if not chat_history:
|
||||||
gr.Warning("Empty chat")
|
gr.Warning("Empty chat")
|
||||||
|
@ -461,7 +471,7 @@ class ChatPage(BasePage):
|
||||||
|
|
||||||
state["app"]["regen"] = True
|
state["app"]["regen"] = True
|
||||||
for chat, refs, state in self.chat_fn(
|
for chat, refs, state in self.chat_fn(
|
||||||
conversation_id, chat_history, settings, state, *selecteds
|
conversation_id, chat_history, settings, state, user_id, *selecteds
|
||||||
):
|
):
|
||||||
new_state = deepcopy(state)
|
new_state = deepcopy(state)
|
||||||
new_state["app"]["regen"] = False
|
new_state["app"]["regen"] = False
|
||||||
|
|
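On the UI side, the Gradio callback consumes this stream as a generator of its own: it forwards each yielded message to the upload progress panel and reads the pipeline's return value out of `StopIteration` once indexing finishes. A reduced sketch of that consumption loop, assuming `indexing_pipeline` exposes the `stream()` generator sketched at the top and that the status strings shown here are placeholders:

```python
def index_fn(files: list[str]):
    """Gradio generator callback: yields (result_text, debug_text) as indexing progresses."""
    outputs: list[str] = []
    debugs: list[str] = []
    stream = indexing_pipeline.stream(files)  # assumed: the generator sketched above
    try:
        while True:
            msg = next(stream)
            if msg.channel == "index":
                name = msg.content["file_path"]
                if msg.content["status"] == "success":
                    outputs.append(f"OK     | {name}")
                else:
                    outputs.append(f"FAILED | {name}: {msg.content.get('message', '')}")
            elif msg.channel == "debug":
                debugs.append(str(msg.content))
            yield "\n".join(outputs), "\n".join(debugs)
    except StopIteration as e:
        file_ids, errors = e.value  # a generator's return value rides on StopIteration.value
    yield "\n".join(outputs), "\n".join(debugs)
```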