(pump:minor) Allow the indexing pipeline to report the indexing progress to the UI (#81)
* Turn the file indexing event into a generator to report progress
* Fix the React text-trimming function
* Refactor file deletion into a method
This commit is contained in:
parent 56dfc8fb53
commit ebf1315569
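The first bullet is the heart of the change: indexing is now driven through a Python generator that yields progress messages while it works and still returns a final result when it finishes. A minimal, standalone sketch of that pattern (plain Python, not the kotaemon API; `index_files` and `consume` are illustrative names):

from typing import Generator


def index_files(paths: list[str]) -> Generator[str, None, list[str]]:
    """Yield progress messages, then return the list of indexed ids."""
    ids = []
    for i, path in enumerate(paths):
        yield f"Indexing [{i + 1}/{len(paths)}]: {path}"
        ids.append(f"id-{i}")  # placeholder for the real indexing work
    return ids                 # becomes StopIteration.value for the caller


def consume(paths: list[str]) -> list[str]:
    gen = index_files(paths)
    while True:
        try:
            print(next(gen))            # progress goes to the UI
        except StopIteration as stop:
            return stop.value           # final result of the generator

The diff below applies this pattern to `BaseFileIndexIndexing.stream()` and to the Gradio handlers that consume it.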
@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import AsyncGenerator, Iterator, Optional
from typing import Any, AsyncGenerator, Iterator, Optional

from theflow import Function, Node, Param, lazy

@@ -58,7 +58,7 @@ class BaseComponent(Function):
@abstractmethod
def run(
self, *args, **kwargs
) -> Document | list[Document] | Iterator[Document] | None:
) -> Document | list[Document] | Iterator[Document] | None | Any:
"""Run the component."""
...
@@ -32,12 +32,13 @@ class Document(BaseDocument):
channel: the channel to show the document. Optional.:
- chat: show in chat message
- info: show in information panel
- index: show in index panel
- debug: show in debug panel
"""

content: Any = None
source: Optional[str] = None
channel: Optional[Literal["chat", "info", "debug"]] = None
channel: Optional[Literal["chat", "info", "index", "debug"]] = None

def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
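With the new `index` channel, a pipeline can tag each yielded message with the UI panel it belongs to. A small usage sketch, assuming the `kotaemon.base.Document` class shown above (the file name is illustrative):

from kotaemon.base import Document

# free-form progress text for the debug panel
progress = Document("Converting report.pdf to text", channel="debug")

# structured per-file status for the index panel
status = Document(
    content={"file_path": "report.pdf", "status": "success"},
    channel="index",
)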
@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Type

from llama_index.readers import PDFReader
from llama_index.readers.base import BaseReader

from kotaemon.base import BaseComponent, Document, Param

@@ -17,18 +18,20 @@ from kotaemon.loaders import (
UnstructuredReader,
)

KH_DEFAULT_FILE_EXTRACTORS: dict[str, Type[BaseReader]] = {
".xlsx": PandasExcelReader,
".docx": UnstructuredReader,
".xls": UnstructuredReader,
".doc": UnstructuredReader,
".html": HtmlReader,
".mhtml": MhtmlReader,
".png": UnstructuredReader,
".jpeg": UnstructuredReader,
".jpg": UnstructuredReader,
".tiff": UnstructuredReader,
".tif": UnstructuredReader,
unstructured = UnstructuredReader()
KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
".xlsx": PandasExcelReader(),
".docx": unstructured,
".xls": unstructured,
".doc": unstructured,
".html": HtmlReader(),
".mhtml": MhtmlReader(),
".png": unstructured,
".jpeg": unstructured,
".jpg": unstructured,
".tiff": unstructured,
".tif": unstructured,
".pdf": PDFReader(),
}
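Because the mapping now stores reader instances (with a single `UnstructuredReader` shared across extensions), a caller can pick a reader by file suffix without instantiating anything. A minimal lookup sketch, assuming the `KH_DEFAULT_FILE_EXTRACTORS` dict above; the helper name `pick_reader` is illustrative:

from pathlib import Path

from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS


def pick_reader(path: str | Path):
    """Return the shared reader registered for this file extension, or None."""
    return KH_DEFAULT_FILE_EXTRACTORS.get(Path(path).suffix, None)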
@@ -64,7 +67,7 @@ class DocumentIngestor(BaseComponent):
def _get_reader(self, input_files: list[str | Path]):
"""Get appropriate readers for the input files based on file extension"""
file_extractors: dict[str, BaseReader] = {
ext: cls() for ext, cls in KH_DEFAULT_FILE_EXTRACTORS.items()
ext: reader for ext, reader in KH_DEFAULT_FILE_EXTRACTORS.items()
}
for ext, cls in self.override_file_extractors.items():
file_extractors[ext] = cls()

@@ -8,6 +8,8 @@ if TYPE_CHECKING:
class BaseReader(BaseComponent):
"""The base class for all readers"""

...
@@ -126,7 +126,7 @@ class BaseIndex(abc.ABC):
...

def get_retriever_pipelines(
self, settings: dict, selected: Any = None
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseComponent"]:
"""Return the retriever pipelines to retrieve the entity from the index"""
return []
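Retriever construction now receives the `user_id` as well, so per-user resources can be wired in. A sketch of the caller side under that new signature (the `build_retrievers` helper and the `indices` argument are illustrative, not part of the commit):

from typing import Any


def build_retrievers(indices: list, settings: dict, user_id: int, selected: Any = None) -> list:
    """Collect retriever pipelines from every index for one user (illustrative helper)."""
    retrievers: list = []
    for index in indices:
        # the new signature threads user_id through to each index
        retrievers += index.get_retriever_pipelines(settings, user_id, selected)
    return retrievers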
@@ -1,10 +1,18 @@
from pathlib import Path
from typing import Optional
from typing import Generator, Optional

from kotaemon.base import BaseComponent
from kotaemon.base import BaseComponent, Document, Param


class BaseFileIndexRetriever(BaseComponent):

Source = Param(help="The SQLAlchemy Source table")
Index = Param(help="The SQLAlchemy Index table")
VS = Param(help="The VectorStore")
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")

@classmethod
def get_user_settings(cls) -> dict:
"""Get the user settings for indexing

@@ -24,20 +32,6 @@ class BaseFileIndexRetriever(BaseComponent):
) -> "BaseFileIndexRetriever":
raise NotImplementedError

def set_resources(self, resources: dict):
"""Set the resources for the indexing pipeline

This will setup the tables, the vector store and docstore.

Args:
resources (dict): the resources for the indexing pipeline
"""
self._Source = resources["Source"]
self._Index = resources["Index"]
self._VS = resources["VectorStore"]
self._DS = resources["DocStore"]
self._fs_path = resources["FileStoragePath"]


class BaseFileIndexIndexing(BaseComponent):
"""The pipeline to index information into the data store

@@ -54,11 +48,45 @@ class BaseFileIndexIndexing(BaseComponent):
- self._DS: the docstore
"""

def run(self, file_paths: str | Path | list[str | Path], *args, **kwargs):
Source = Param(help="The SQLAlchemy Source table")
Index = Param(help="The SQLAlchemy Index table")
VS = Param(help="The VectorStore")
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")

def run(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> tuple[list[str | None], list[str | None]]:
"""Run the indexing pipeline

Args:
file_paths (str | Path | list[str | Path]): the file paths to index

Returns:
- the indexed file ids (each file id corresponds to an input file path, or
None if the indexing failed for that file path)
- the error messages (each error message corresponds to an input file path,
or None if the indexing was successful for that file path)
"""
raise NotImplementedError

def stream(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
"""Stream the indexing pipeline

Args:
file_paths (str | Path | list[str | Path]): the file paths to index

Yields:
Document: the output message to the UI, must have channel == index or debug

Returns:
- the indexed file ids (each file id corresponds to an input file path, or
None if the indexing failed for that file path)
- the error messages (each error message corresponds to an input file path,
or None if the indexing was successful for that file path)
"""
raise NotImplementedError
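The `stream()` contract above means a caller gets two things from one generator: `Document` progress messages while it runs, and the `(file_ids, errors)` pair once it stops. A minimal consumer sketch, assuming a `pipeline` object implementing `BaseFileIndexIndexing.stream()` (the `drive_indexing` name and the `reindex` keyword are illustrative):

def drive_indexing(pipeline, file_paths):
    """Print progress Documents and return the pipeline's final (file_ids, errors)."""
    stream = pipeline.stream(file_paths, reindex=False)
    while True:
        try:
            doc = next(stream)
            print(f"[{doc.channel}] {doc.content}")   # forward progress to the UI
        except StopIteration as stop:
            file_ids, errors = stop.value             # the generator's return value
            return file_ids, errors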
@@ -78,20 +106,6 @@ class BaseFileIndexIndexing(BaseComponent):
"""
return {}

def set_resources(self, resources: dict):
"""Set the resources for the indexing pipeline

This will setup the tables, the vector store and docstore.

Args:
resources (dict): the resources for the indexing pipeline
"""
self._Source = resources["Source"]
self._Index = resources["Index"]
self._VS = resources["VectorStore"]
self._DS = resources["DocStore"]
self._fs_path = resources["FileStoragePath"]

def copy_to_filestorage(
self, file_paths: str | Path | list[str | Path]
) -> list[str]:

@@ -113,7 +127,7 @@ class BaseFileIndexIndexing(BaseComponent):
for file_path in file_paths:
with open(file_path, "rb") as f:
paths.append(sha256(f.read()).hexdigest())
shutil.copy(file_path, self._fs_path / paths[-1])
shutil.copy(file_path, self.FSPath / paths[-1])

return paths

@@ -362,13 +362,17 @@ class FileIndex(BaseIndex):
stripped_settings[key] = value

obj = self._indexing_pipeline_cls.get_pipeline(stripped_settings, self.config)
obj.set_resources(resources=self._resources)
obj._user_id = user_id
obj.Source = self._resources["Source"]
obj.Index = self._resources["Index"]
obj.VS = self._vs
obj.DS = self._docstore
obj.FSPath = self._fs_path
obj.user_id = user_id

return obj

def get_retriever_pipelines(
self, settings: dict, selected: Any = None
self, settings: dict, user_id: int, selected: Any = None
) -> list["BaseFileIndexRetriever"]:
# retrieval settings
prefix = f"index.options.{self.id}."
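Resource wiring has moved from a `set_resources()` call to plain attribute assignment on the declared `Param`s. A condensed sketch of that new wiring, assuming a pipeline object exposing the six params shown above (the `wire_pipeline` helper and its argument names are illustrative):

def wire_pipeline(obj, resources: dict, vs, docstore, fs_path, user_id: int):
    """Attach storage handles and the user id directly to the pipeline's params."""
    obj.Source = resources["Source"]    # SQLAlchemy Source table
    obj.Index = resources["Index"]      # SQLAlchemy Index table
    obj.VS = vs                         # vector store
    obj.DS = docstore                   # doc store
    obj.FSPath = fs_path                # file storage path
    obj.user_id = user_id
    return obj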
@@ -387,7 +391,12 @@ class FileIndex(BaseIndex):
obj = cls.get_pipeline(stripped_settings, self.config, selected_ids)
if obj is None:
continue
obj.set_resources(self._resources)
obj.Source = self._resources["Source"]
obj.Index = self._resources["Index"]
obj.VS = self._vs
obj.DS = self._docstore
obj.FSPath = self._fs_path
obj.user_id = user_id
retrievers.append(obj)

return retrievers

@@ -7,13 +7,13 @@ from collections import defaultdict
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Optional
from typing import Generator, Optional

import gradio as gr
from ktem.components import filestorage_path
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
from llama_index.readers.base import BaseReader
from llama_index.readers.file.base import default_file_metadata_func
from llama_index.vector_stores import (
FilterCondition,
FilterOperator,

@@ -26,10 +26,12 @@ from sqlalchemy.orm import Session
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string

from kotaemon.base import RetrievedDocument
from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests import DocumentIngestor
from kotaemon.indices.ingests.files import KH_DEFAULT_FILE_EXTRACTORS
from kotaemon.indices.rankings import BaseReranking, LLMReranking
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter

from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

@@ -43,7 +45,7 @@ def dev_settings():
if hasattr(settings, "FILE_INDEX_PIPELINE_FILE_EXTRACTORS"):
file_extractors = {
key: import_dotted_string(value, safe=False)
key: import_dotted_string(value, safe=False)()
for key, value in settings.FILE_INDEX_PIPELINE_FILE_EXTRACTORS.items()
}

@@ -72,12 +74,20 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
mmr: whether to use mmr to re-rank the documents
"""

vector_retrieval: VectorRetrieval = VectorRetrieval.withx()
embedding: BaseEmbeddings
reranker: BaseReranking = LLMReranking.withx()
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5

@Node.auto(depends_on=["embedding", "VS", "DS"])
def vector_retrieval(self) -> VectorRetrieval:
return VectorRetrieval(
embedding=self.embedding,
vector_store=self.VS,
doc_store=self.DS,
)

def run(
self,
text: str,
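The `@Node.auto(...)` declaration above builds `vector_retrieval` on demand from its declared dependencies instead of storing a pre-wired instance. A trimmed sketch of the same pattern, assuming kotaemon's theflow-based `Param`/`Node` semantics as used in this diff (the `MyRetriever` class is illustrative, not part of the commit):

from kotaemon.base import BaseComponent, Node, Param
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.indices import VectorRetrieval


class MyRetriever(BaseComponent):
    embedding: BaseEmbeddings
    VS = Param(help="The VectorStore")
    DS = Param(help="The DocStore")

    @Node.auto(depends_on=["embedding", "VS", "DS"])
    def vector_retrieval(self) -> VectorRetrieval:
        # constructed from embedding, VS and DS when the node is used
        return VectorRetrieval(
            embedding=self.embedding, vector_store=self.VS, doc_store=self.DS
        )

    def run(self, text: str):
        return self.vector_retrieval(text)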
@@ -95,13 +105,11 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
logger.info(f"Skip retrieval because of no selected files: {self}")
return []

Index = self._Index

retrieval_kwargs = {}
with Session(engine) as session:
stmt = select(Index).where(
Index.relation_type == "vector",
Index.source_id.in_(doc_ids), # type: ignore
stmt = select(self.Index).where(
self.Index.relation_type == "vector",
self.Index.source_id.in_(doc_ids),
)
results = session.execute(stmt)
vs_ids = [r[0].target_id for r in results.all()]

@@ -186,7 +194,7 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"component": "dropdown",
},
"num_retrieval": {
"name": "Number of documents to retrieve",
"name": "Number of document chunks to retrieve",
"value": 3,
"component": "number",
},

@@ -228,6 +236,11 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
get_extra_table=user_settings["prioritize_table"],
top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"],
embedding=embedding_models_manager[
index_settings.get(
"embedding", embedding_models_manager.get_default_name()
)
],
)
if not user_settings["use_reranking"]:
retriever.reranker = None # type: ignore
@@ -236,226 +249,346 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
user_settings["reranking_llm"], llms.get_default()
)

retriever.vector_retrieval.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name())
]
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=True)
return retriever

def set_resources(self, resources: dict):
super().set_resources(resources)
self.vector_retrieval.vector_store = self._VS
self.vector_retrieval.doc_store = self._DS


class IndexPipeline(BaseComponent):
"""Index a single file"""

loader: BaseReader
splitter: BaseSplitter
chunk_batch_size: int = 50

Source = Param(help="The SQLAlchemy Source table")
Index = Param(help="The SQLAlchemy Index table")
VS = Param(help="The VectorStore")
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")
embedding: BaseEmbeddings

@Node.auto(depends_on=["Source", "Index", "embedding"])
def vector_indexing(self) -> VectorIndexing:
return VectorIndexing(
vector_store=self.VS, doc_store=self.DS, embedding=self.embedding
)

def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
chunks = []
n_chunks = 0
for cidx, chunk in enumerate(self.splitter(docs)):
chunks.append(chunk)
if cidx % self.chunk_batch_size == 0:
self.handle_chunks(chunks, file_id)
n_chunks += len(chunks)
chunks = []
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
)

if chunks:
self.handle_chunks(chunks, file_id)
n_chunks += len(chunks)
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks", channel="debug"
)

return n_chunks

def handle_chunks(self, chunks, file_id):
"""Run chunks"""
# run embedding, add to both vector store and doc store
self.vector_indexing(chunks)

# record in the index
with Session(engine) as session:
nodes = []
for chunk in chunks:
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="document",
)
)
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="vector",
)
)
session.add_all(nodes)
session.commit()

def get_id_if_exists(self, file_path: Path) -> Optional[str]:
"""Check if the file is already indexed

Args:
file_path: the path to the file

Returns:
the file id if the file is indexed, otherwise None
"""
with Session(engine) as session:
stmt = select(self.Source).where(self.Source.name == file_path.name)
item = session.execute(stmt).first()
if item:
return item[0].id

return None

def store_file(self, file_path: Path) -> str:
"""Store file into the database and storage, return the file id

Args:
file_path: the path to the file

Returns:
the file id
"""
with file_path.open("rb") as fi:
file_hash = sha256(fi.read()).hexdigest()

shutil.copy(file_path, self.FSPath / file_hash)
source = self.Source(
name=file_path.name,
path=file_hash,
size=file_path.stat().st_size,
user=self.user_id, # type: ignore
)
with Session(engine) as session:
session.add(source)
session.commit()
file_id = source.id

return file_id

def finish(self, file_id: str, file_path: Path) -> str:
"""Finish the indexing"""
with Session(engine) as session:
stmt = select(self.Index.target_id).where(self.Index.source_id == file_id)
doc_ids = [_[0] for _ in session.execute(stmt)]
if doc_ids:
docs = self.DS.get(doc_ids)
stmt = select(self.Source).where(self.Source.id == file_id)
result = session.execute(stmt).first()
if result:
item = result[0]
item.text_length = sum([len(doc.text) for doc in docs])
session.add(item)
session.commit()

return file_id

def delete_file(self, file_id: str):
"""Delete a file from the db, including its chunks in docstore and vectorstore

Args:
file_id: the file id
"""
with Session(engine) as session:
session.execute(delete(self.Source).where(self.Source.id == file_id))
vs_ids, ds_ids = [], []
index = session.execute(
select(self.Index).where(self.Index.source_id == file_id)
).all()
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
else:
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
self.VS.delete(vs_ids)
self.DS.delete(ds_ids)

def run(self, file_path: str | Path, reindex: bool, **kwargs) -> str:
"""Index the file and return the file id"""
# check for duplication
file_path = Path(file_path).resolve()
file_id = self.get_id_if_exists(file_path)
if file_id is not None:
if not reindex:
raise ValueError(
f"File {file_path.name} already indexed. Please rerun with "
"reindex=True to force reindexing."
)
else:
# remove the existing records
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
file_id = self.store_file(file_path)

# extract the file
extra_info = default_file_metadata_func(str(file_path))
docs = self.loader.load_data(file_path, extra_info=extra_info)
for _ in self.handle_docs(docs, file_id, file_path.name):
continue
self.finish(file_id, file_path)

return file_id

def stream(
self, file_path: str | Path, reindex: bool, **kwargs
) -> Generator[Document, None, str]:
# check for duplication
file_path = Path(file_path).resolve()
file_id = self.get_id_if_exists(file_path)
if file_id is not None:
if not reindex:
raise ValueError(
f"File {file_path.name} already indexed. Please rerun with "
"reindex=True to force reindexing."
)
else:
# remove the existing records
yield Document(f" => Removing old {file_path.name}", channel="debug")
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
file_id = self.store_file(file_path)

# extract the file
extra_info = default_file_metadata_func(str(file_path))
yield Document(f" => Converting {file_path.name} to text", channel="debug")
docs = self.loader.load_data(file_path, extra_info=extra_info)
yield Document(f" => Converted {file_path.name} to text", channel="debug")
yield from self.handle_docs(docs, file_id, file_path.name)

self.finish(file_id, file_path)

yield Document(f" => Finished indexing {file_path.name}", channel="debug")
return file_id


class IndexDocumentPipeline(BaseFileIndexIndexing):
"""Store the documents and index the content into vector store and doc store
"""Index the file. Decide which pipeline based on the file type.

Args:
indexing_vector_pipeline: pipeline to index the documents
file_ingestor: ingestor to ingest the documents
This method is essentially a factory to decide which indexing pipeline to use.

We can decide the pipeline programmatically, and/or automatically based on an LLM.
If we based on the LLM, essentially we will log the LLM thought process in a file,
and then during the indexing, we will read that file to decide which pipeline
to use, and then log the operation in that file. Overtime, the LLM can learn to
decide which pipeline should be used.
"""

indexing_vector_pipeline: VectorIndexing = VectorIndexing.withx()
file_ingestor: DocumentIngestor = DocumentIngestor.withx()
embedding: BaseEmbeddings

@classmethod
def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
obj = cls(
embedding=embedding_models_manager[
index_settings.get(
"embedding", embedding_models_manager.get_default_name()
)
]
)
return obj

def route(self, file_path: Path) -> IndexPipeline:
"""Decide the pipeline based on the file type

Can subclass this method for a more elaborate pipeline routing strategy.
"""
readers, chunk_size, chunk_overlap = dev_settings()

ext = file_path.suffix
reader = readers.get(ext, KH_DEFAULT_FILE_EXTRACTORS.get(ext, None))
if reader is None:
raise NotImplementedError(
f"No supported pipeline to index {file_path.name}. Please specify "
"the suitable pipeline for this file type in the settings."
)

pipeline: IndexPipeline = IndexPipeline(
loader=reader,
splitter=TokenSplitter(
chunk_size=chunk_size or 1024,
chunk_overlap=chunk_overlap or 256,
separator="\n\n",
backup_separators=["\n", ".", "\u200B"],
),
Source=self.Source,
Index=self.Index,
VS=self.VS,
DS=self.DS,
FSPath=self.FSPath,
user_id=self.user_id,
embedding=self.embedding,
)

return pipeline

def run(
self,
file_paths: str | Path | list[str | Path],
reindex: bool = False,
**kwargs, # type: ignore
):
"""Index the list of documents

This function will extract the files, persist the files to storage,
index the files.

Args:
file_paths: list of file paths to index
reindex: whether to force reindexing the files if they exist

Returns:
list of split nodes
"""
Source = self._Source
Index = self._Index

self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> tuple[list[str | None], list[str | None]]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]

to_index: list[str] = []
file_to_hash: dict[str, str] = {}
errors = []
to_update = []

file_ids: list[str | None] = []
errors: list[str | None] = []
for file_path in file_paths:
abs_path = str(Path(file_path).resolve())
with open(abs_path, "rb") as fi:
file_hash = sha256(fi.read()).hexdigest()
file_path = Path(file_path)

file_to_hash[abs_path] = file_hash
try:
pipeline = self.route(file_path)
file_id = pipeline.run(file_path, reindex=reindex, **kwargs)
file_ids.append(file_id)
errors.append(None)
except Exception as e:
logger.error(e)
file_ids.append(None)
errors.append(str(e))

with Session(engine) as session:
statement = select(Source).where(Source.name == Path(abs_path).name)
item = session.execute(statement).first()
return file_ids, errors

if item:
if not reindex:
errors.append(Path(abs_path).name)
continue
else:
to_update.append(Path(abs_path).name)
def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[Document, None, tuple[list[str | None], list[str | None]]]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]

to_index.append(abs_path)

if errors:
error_files = ", ".join(errors)
if len(error_files) > 100:
error_files = error_files[:80] + "..."
print(
"Skip these files already exist. Please rename/remove them or "
f"enable reindex:\n{errors}"
)
self.warning(
"Skip these files already exist. Please rename/remove them or "
f"enable reindex:\n{error_files}"
file_ids: list[str | None] = []
errors: list[str | None] = []
n_files = len(file_paths)
for idx, file_path in enumerate(file_paths):
file_path = Path(file_path)
yield Document(
content=f"Indexing [{idx+1}/{n_files}]: {file_path.name}",
channel="debug",
)

if not to_index:
return [], []

# persist the files to storage
for path in to_index:
shutil.copy(path, filestorage_path / file_to_hash[path])

# extract the file & prepare record info
file_to_source: dict = {}
extraction_errors = []
nodes = []
for file_path, file_hash in file_to_hash.items():
if str(Path(file_path).resolve()) not in to_index:
continue

extraction_result = self.file_ingestor(file_path)
if not extraction_result:
extraction_errors.append(Path(file_path).name)
continue
nodes.extend(extraction_result)
source = Source(
name=Path(file_path).name,
path=file_hash,
size=Path(file_path).stat().st_size,
user=self._user_id, # type: ignore
try:
pipeline = self.route(file_path)
file_id = yield from pipeline.stream(
file_path, reindex=reindex, **kwargs
)
file_to_source[file_path] = source

if extraction_errors:
msg = "Failed to extract these files: {}".format(
", ".join(extraction_errors)
file_ids.append(file_id)
errors.append(None)
yield Document(
content={"file_path": file_path, "status": "success"},
channel="index",
)
print(msg)
self.warning(msg)

if not nodes:
return [], []

print(
"Extracted",
len(to_index) - len(extraction_errors),
"files into",
len(nodes),
"nodes",
)

# index the files
print("Indexing the files into vector store")
self.indexing_vector_pipeline(nodes)
print("Finishing indexing the files into vector store")

# persist to the index
print("Persisting the vector and the document into index")
file_ids = []
to_update = list(set(to_update))
with Session(engine) as session:
if to_update:
session.execute(delete(Source).where(Source.name.in_(to_update)))

for source in file_to_source.values():
session.add(source)
session.commit()
for source in file_to_source.values():
file_ids.append(source.id)

for node in nodes:
file_path = str(node.metadata["file_path"])
node.source = str(file_to_source[file_path].id)
file_to_source[file_path].text_length += len(node.text)

session.flush()
session.commit()

with Session(engine) as session:
for node in nodes:
index = Index(
source_id=node.source,
target_id=node.doc_id,
relation_type="document",
)
session.add(index)
for node in nodes:
index = Index(
source_id=node.source,
target_id=node.doc_id,
relation_type="vector",
)
session.add(index)
session.commit()

print("Finishing persisting the vector and the document into index")
print(f"{len(nodes)} nodes are indexed")
return nodes, file_ids

@classmethod
def get_user_settings(cls) -> dict:
return {
"index_parser": {
"name": "Index parser",
"value": "normal",
"choices": [
("PDF text parser", "normal"),
("Mathpix", "mathpix"),
("Advanced ocr", "ocr"),
("Multimodal parser", "multimodal"),
],
"component": "dropdown",
except Exception as e:
logger.error(e)
file_ids.append(None)
errors.append(str(e))
yield Document(
content={
"file_path": file_path,
"status": "failed",
"message": str(e),
},
}
channel="index",
)

@classmethod
def get_pipeline(cls, user_settings, index_settings) -> "IndexDocumentPipeline":
"""Get the pipeline based on the setting"""
obj = cls()
obj.file_ingestor.pdf_mode = user_settings["index_parser"]

file_extractors, chunk_size, chunk_overlap = dev_settings()
if file_extractors:
obj.file_ingestor.override_file_extractors = file_extractors
if chunk_size:
obj.file_ingestor.text_splitter.chunk_size = chunk_size
if chunk_overlap:
obj.file_ingestor.text_splitter.chunk_overlap = chunk_overlap

obj.indexing_vector_pipeline.embedding = embedding_models_manager[
index_settings.get("embedding", embedding_models_manager.get_default_name())
]

return obj

def set_resources(self, resources: dict):
super().set_resources(resources)
self.indexing_vector_pipeline.vector_store = self._VS
self.indexing_vector_pipeline.doc_store = self._DS

def warning(self, msg):
gr.Warning(msg)
return file_ids, errors
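`IndexPipeline.handle_docs` above batches chunks (`chunk_batch_size = 50`) and yields a progress message per flushed batch rather than per chunk, which keeps UI updates cheap. The batching idea in isolation, as plain Python (the comment marks where the real embedding and store writes would go; `index_in_batches` is an illustrative name):

from typing import Generator, Iterable


def index_in_batches(
    chunks: Iterable[str], batch_size: int = 50
) -> Generator[str, None, int]:
    """Index chunks in batches, yielding one progress line per flushed batch."""
    batch: list[str] = []
    total = 0
    for chunk in chunks:
        batch.append(chunk)
        if len(batch) >= batch_size:
            total += len(batch)  # the real embedding / store write would happen here
            batch = []
            yield f"Processed {total} chunks"
    if batch:  # flush the remainder
        total += len(batch)
        yield f"Processed {total} chunks"
    return total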
@@ -1,6 +1,7 @@
import os
import tempfile
from pathlib import Path
from typing import Generator

import gradio as gr
import pandas as pd

@@ -63,9 +64,6 @@ class DirectoryUpload(BasePage):
)

self.upload_button = gr.Button("Upload and Index")
self.file_output = gr.File(
visible=False, label="Output files (debug purpose)"
)


class FileIndexPage(BasePage):

@@ -127,11 +125,23 @@ class FileIndexPage(BasePage):
self.upload_button = gr.Button(
"Upload and Index", variant="primary"
)
self.file_output = gr.File(
visible=False, label="Output files (debug purpose)"
)

with gr.Column(scale=4):
with gr.Column(visible=False) as self.upload_progress_panel:
gr.Markdown("## Upload Progress")
with gr.Row():
self.upload_result = gr.Textbox(
lines=1, max_lines=20, label="Upload result"
)
self.upload_info = gr.Textbox(
lines=1, max_lines=20, label="Upload info"
)
self.btn_close_upload_progress_panel = gr.Button(
"Clear Upload Info and Close",
variant="secondary",
elem_classes=["right-button"],
)

gr.Markdown("## File List")
self.file_list_state = gr.State(value=None)
self.file_list = gr.DataFrame(
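The two Textboxes above are filled by a generator callback: each yield of a `(result, info)` pair repaints both boxes, which is how indexing progress reaches the page. A stripped-down, standalone Gradio sketch of the same wiring (a toy demo, not the ktem page; `fake_index` and the component labels are illustrative):

import time

import gradio as gr


def fake_index(n_files: float):
    """Yield (result, info) pairs; each yield updates both textboxes."""
    results, infos = [], []
    for i in range(int(n_files)):
        time.sleep(0.2)  # stand-in for real indexing work
        results.append(f"\u2705 | file_{i}.pdf")
        infos.append(f"Indexing [{i + 1}/{int(n_files)}]: file_{i}.pdf")
        yield "\n".join(results), "\n".join(infos)


with gr.Blocks() as demo:
    n = gr.Number(value=3, label="Number of files")
    btn = gr.Button("Upload and Index")
    result = gr.Textbox(lines=1, max_lines=20, label="Upload result")
    info = gr.Textbox(lines=1, max_lines=20, label="Upload info")
    btn.click(fn=fake_index, inputs=[n], outputs=[result, info])

# demo.queue().launch()  # generator callbacks stream their updates through the queue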
@@ -261,6 +271,9 @@ class FileIndexPage(BasePage):
)

onUploaded = self.upload_button.click(
fn=lambda: gr.update(visible=True),
outputs=[self.upload_progress_panel],
).then(
fn=self.index_fn,
inputs=[
self.files,

@@ -268,16 +281,28 @@
self._app.settings_state,
self._app.user_id,
],
outputs=[self.file_output],
outputs=[self.upload_result, self.upload_info],
concurrency_limit=20,
).then(
)

uploadedEvent = onUploaded.then(
fn=self.list_file,
inputs=[self._app.user_id],
outputs=[self.file_list_state, self.file_list],
concurrency_limit=20,
)
for event in self._app.get_event(f"onFileIndex{self._index.id}Changed"):
onUploaded = onUploaded.then(**event)
uploadedEvent = uploadedEvent.then(**event)

_ = onUploaded.success(
fn=lambda: None,
outputs=[self.files],
)

self.btn_close_upload_progress_panel.click(
fn=lambda: (gr.update(visible=False), "", ""),
outputs=[self.upload_progress_panel, self.upload_result, self.upload_info],
)

self.file_list.select(
fn=self.interact_file_list,

@@ -294,7 +319,9 @@ class FileIndexPage(BasePage):
outputs=[self.file_list_state, self.file_list],
)

def index_fn(self, files, reindex: bool, settings, user_id):
def index_fn(
self, files, reindex: bool, settings, user_id
) -> Generator[tuple[str, str], None, None]:
"""Upload and index the files

Args:

@@ -305,35 +332,56 @@ class FileIndexPage(BasePage):
"""
if not files:
gr.Info("No uploaded file")
return gr.update()
yield "", ""
return

errors = self.validate(files)
if errors:
gr.Warning(", ".join(errors))
return gr.update()
yield "", ""
return

gr.Info(f"Start indexing {len(files)} files...")

# get the pipeline
indexing_pipeline = self._index.get_indexing_pipeline(settings, user_id)

result = indexing_pipeline(files, reindex=reindex)
if result is None:
gr.Info("Finish indexing")
outputs, debugs = [], []
# stream the output
output_stream = indexing_pipeline.stream(files, reindex=reindex)
try:
while True:
response = next(output_stream)
if response is None:
continue
if response.channel == "index":
if response.content["status"] == "success":
outputs.append(f"\u2705 | {response.content['file_path'].name}")
elif response.content["status"] == "failed":
outputs.append(
f"\u274c | {response.content['file_path'].name}: "
f"{response.content['message']}"
)
elif response.channel == "debug":
debugs.append(response.text)
yield "\n".join(outputs), "\n".join(debugs)
except StopIteration as e:
result, errors = e.value
except Exception as e:
debugs.append(f"Error: {e}")
yield "\n".join(outputs), "\n".join(debugs)
return
output_nodes, _ = result
gr.Info(f"Finish indexing into {len(output_nodes)} chunks")

# download the file
text = "\n\n".join([each.text for each in output_nodes])
handler, file_path = tempfile.mkstemp(suffix=".txt")
with open(file_path, "w", encoding="utf-8") as f:
f.write(text)
os.close(handler)
n_successes = len([_ for _ in result if _])
if n_successes:
gr.Info(f"Successfully index {n_successes} files")
n_errors = len([_ for _ in errors if _])
if n_errors:
gr.Warning(f"Have errors for {n_errors} files")

return gr.update(value=file_path, visible=True)

def index_files_from_dir(self, folder_path, reindex, settings, user_id):
def index_files_from_dir(
self, folder_path, reindex, settings, user_id
) -> Generator[tuple[str, str], None, None]:
"""This should be constructable by users

It means that the users can build their own index.
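Note that `index_files_from_dir` (continued below) now delegates with `yield from` instead of `return`: since the handler already yields its own early-exit messages, a bare `return index_fn(...)` inside the generator would only set the return value and forward no progress at all. A tiny illustration of the difference, using stand-in functions with the same names:

from typing import Generator


def index_fn(files: list[str]) -> Generator[str, None, None]:
    """Stand-in for the real index_fn: yields one progress line per file."""
    for name in files:
        yield f"indexed {name}"


def index_files_from_dir(files: list[str]) -> Generator[str, None, None]:
    if not files:
        yield "nothing to index"
        return
    # re-yield every progress message from the inner generator to keep streaming
    yield from index_fn(files)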
@@ -363,6 +411,7 @@ class FileIndexPage(BasePage):
2. Implement the transformation from artifacts to UI
"""
if not folder_path:
yield "", ""
return

import fnmatch

@@ -401,7 +450,7 @@ class FileIndexPage(BasePage):
for p in exclude_patterns:
files = [f for f in files if not fnmatch.fnmatch(name=f, pat=p)]

return self.index_fn(files, reindex, settings, user_id)
yield from self.index_fn(files, reindex, settings, user_id)

def list_file(self, user_id):
if user_id is None:
@@ -99,6 +99,7 @@ class ChatPage(BasePage):
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
self._app.user_id,
]
+ self._indices_input,
outputs=[

@@ -127,6 +128,7 @@ class ChatPage(BasePage):
self.chat_panel.chatbot,
self._app.settings_state,
self.chat_state,
self._app.user_id,
]
+ self._indices_input,
outputs=[

@@ -360,7 +362,7 @@ class ChatPage(BasePage):
session.add(result)
session.commit()

def create_pipeline(self, settings: dict, state: dict, *selecteds):
def create_pipeline(self, settings: dict, state: dict, user_id: int, *selecteds):
"""Create the pipeline from settings

Args:

@@ -385,7 +387,9 @@ class ChatPage(BasePage):
if isinstance(index.selector, tuple):
for i in index.selector:
index_selected.append(selecteds[i])
iretrievers = index.get_retriever_pipelines(settings, index_selected)
iretrievers = index.get_retriever_pipelines(
settings, user_id, index_selected
)
retrievers += iretrievers

# prepare states

@@ -398,7 +402,9 @@ class ChatPage(BasePage):
return pipeline, reasoning_state

def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
def chat_fn(
self, conversation_id, chat_history, settings, state, user_id, *selecteds
):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]

@@ -406,7 +412,9 @@ class ChatPage(BasePage):
queue: asyncio.Queue[Optional[dict]] = asyncio.Queue()

# construct the pipeline
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
pipeline, reasoning_state = self.create_pipeline(
settings, state, user_id, *selecteds
)
pipeline.set_output_queue(queue)

text, refs = "", ""

@@ -452,7 +460,9 @@ class ChatPage(BasePage):
print(f"Generate nothing: {empty_msg}")
yield chat_history + [(chat_input, text or empty_msg)], refs, state

def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
def regen_fn(
self, conversation_id, chat_history, settings, state, user_id, *selecteds
):
"""Regen function"""
if not chat_history:
gr.Warning("Empty chat")

@@ -461,7 +471,7 @@ class ChatPage(BasePage):
state["app"]["regen"] = True
for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, *selecteds
conversation_id, chat_history, settings, state, user_id, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False