kotaemon/libs/ktem/ktem/index/file/pipelines.py
Pedro Lima 908452cc18
allow chunk_overlap with zero value (#457)
Co-authored-by: Tadashi <tadashi@cinnamon.is>
2024-11-05 14:12:52 +07:00


from __future__ import annotations

import json
import logging
import shutil
import threading
import time
import warnings
from collections import defaultdict
from copy import deepcopy
from functools import lru_cache
from hashlib import sha256
from pathlib import Path
from typing import Generator, Optional, Sequence

import tiktoken
from ktem.db.models import engine
from ktem.embeddings.manager import embedding_models_manager
from ktem.llms.manager import llms
from ktem.rerankings.manager import reranking_models_manager
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.file.base import default_file_metadata_func
from llama_index.core.vector_stores import (
FilterCondition,
FilterOperator,
MetadataFilter,
MetadataFilters,
)
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from sqlalchemy import delete, select
from sqlalchemy.orm import Session
from theflow.settings import settings
from theflow.utils.modules import import_dotted_string

from kotaemon.base import BaseComponent, Document, Node, Param, RetrievedDocument
from kotaemon.embeddings import BaseEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.indices.ingests.files import (
KH_DEFAULT_FILE_EXTRACTORS,
adobe_reader,
azure_reader,
unstructured,
web_reader,
)
from kotaemon.indices.rankings import BaseReranking, LLMReranking, LLMTrulensScoring
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter

from .base import BaseFileIndexIndexing, BaseFileIndexRetriever

logger = logging.getLogger(__name__)


@lru_cache
def dev_settings():
    """Retrieve the developer settings from flowsettings.py"""
file_extractors = {}
if hasattr(settings, "FILE_INDEX_PIPELINE_FILE_EXTRACTORS"):
file_extractors = {
key: import_dotted_string(value, safe=False)()
for key, value in settings.FILE_INDEX_PIPELINE_FILE_EXTRACTORS.items()
}
chunk_size = None
if hasattr(settings, "FILE_INDEX_PIPELINE_SPLITTER_CHUNK_SIZE"):
chunk_size = settings.FILE_INDEX_PIPELINE_SPLITTER_CHUNK_SIZE
chunk_overlap = None
if hasattr(settings, "FILE_INDEX_PIPELINE_SPLITTER_CHUNK_OVERLAP"):
chunk_overlap = settings.FILE_INDEX_PIPELINE_SPLITTER_CHUNK_OVERLAP
return file_extractors, chunk_size, chunk_overlap
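
# A hedged sketch of the flowsettings.py overrides this helper looks for
# (attribute names come from the hasattr checks above; the values here are
# illustrative assumptions, not defaults):
#
#     FILE_INDEX_PIPELINE_FILE_EXTRACTORS = {".pdf": "path.to.MyReader"}
#     FILE_INDEX_PIPELINE_SPLITTER_CHUNK_SIZE = 1024
#     FILE_INDEX_PIPELINE_SPLITTER_CHUNK_OVERLAP = 0  # zero is a valid value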

_default_token_func = tiktoken.encoding_for_model("gpt-3.5-turbo").encode


class DocumentRetrievalPipeline(BaseFileIndexRetriever):
"""Retrieve relevant document
Args:
vector_retrieval: the retrieval pipeline that return the relevant documents
given a text query
reranker: the reranking pipeline that re-rank and filter the retrieved
documents
get_extra_table: if True, for each retrieved document, the pipeline will look
for surrounding tables (e.g. within the page)
top_k: number of documents to retrieve
mmr: whether to use mmr to re-rank the documents
"""
embedding: BaseEmbeddings
rerankers: Sequence[BaseReranking] = []
# use LLM to create relevant scores for displaying on UI
llm_scorer: LLMReranking | None = LLMReranking.withx()
get_extra_table: bool = False
mmr: bool = False
top_k: int = 5
retrieval_mode: str = "hybrid"

    @Node.auto(depends_on=["embedding", "VS", "DS"])
def vector_retrieval(self) -> VectorRetrieval:
return VectorRetrieval(
embedding=self.embedding,
vector_store=self.VS,
doc_store=self.DS,
retrieval_mode=self.retrieval_mode, # type: ignore
rerankers=self.rerankers,
)

    def run(
self,
text: str,
doc_ids: Optional[list[str]] = None,
*args,
**kwargs,
) -> list[RetrievedDocument]:
"""Retrieve document excerpts similar to the text
Args:
text: the text to retrieve similar documents
doc_ids: list of document ids to constraint the retrieval
"""
# flatten doc_ids in case of group of doc_ids are passed
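        # e.g. (illustrative):
        #     ['["id-1", "id-2"]', "id-3"] -> ["id-1", "id-2", "id-3"]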
if doc_ids:
flatten_doc_ids = []
for doc_id in doc_ids:
if doc_id.startswith("["):
flatten_doc_ids.extend(json.loads(doc_id))
else:
flatten_doc_ids.append(doc_id)
doc_ids = flatten_doc_ids
print("searching in doc_ids", doc_ids)
if not doc_ids:
logger.info(f"Skip retrieval because of no selected files: {self}")
return []
retrieval_kwargs: dict = {}
with Session(engine) as session:
stmt = select(self.Index).where(
self.Index.relation_type == "document",
self.Index.source_id.in_(doc_ids),
)
results = session.execute(stmt)
chunk_ids = [r[0].target_id for r in results.all()]
# do first round top_k extension
retrieval_kwargs["do_extend"] = True
retrieval_kwargs["scope"] = chunk_ids
retrieval_kwargs["filters"] = MetadataFilters(
filters=[
MetadataFilter(
key="file_id",
value=doc_ids,
operator=FilterOperator.IN,
)
],
condition=FilterCondition.OR,
)
if self.mmr:
# TODO: double check that llama-index MMR works correctly
retrieval_kwargs["mode"] = VectorStoreQueryMode.MMR
retrieval_kwargs["mmr_threshold"] = 0.5
# rerank
s_time = time.time()
print(f"retrieval_kwargs: {retrieval_kwargs.keys()}")
docs = self.vector_retrieval(text=text, top_k=self.top_k, **retrieval_kwargs)
print("retrieval step took", time.time() - s_time)
if not self.get_extra_table:
return docs
        # retrieve extra nodes related to tables
        table_pages = defaultdict(list)
        retrieved_id = set(doc.doc_id for doc in docs)
        for doc in docs:
            if "page_label" not in doc.metadata:
                continue
            if "file_name" not in doc.metadata:
                warnings.warn(
                    "file_name not in metadata while page_label is in metadata: "
                    f"{doc.metadata}"
                )
                # skip, otherwise the lookup below would raise KeyError
                continue
            table_pages[doc.metadata["file_name"]].append(doc.metadata["page_label"])
queries: list[dict] = [
{"$and": [{"file_name": {"$eq": fn}}, {"page_label": {"$in": pls}}]}
for fn, pls in table_pages.items()
]
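        # e.g. a single per-file query (illustrative values):
        #     {"$and": [{"file_name": {"$eq": "report.pdf"}},
        #               {"page_label": {"$in": ["3", "4"]}}]}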
if queries:
try:
extra_docs = self.vector_retrieval(
text="",
top_k=50,
where=queries[0] if len(queries) == 1 else {"$or": queries},
)
for doc in extra_docs:
if doc.doc_id not in retrieved_id:
docs.append(doc)
except Exception:
print("Error retrieving additional tables")
return docs

    def generate_relevant_scores(
self, query: str, documents: list[RetrievedDocument]
) -> list[RetrievedDocument]:
docs = (
documents
if not self.llm_scorer
else self.llm_scorer(documents=documents, query=query)
)
return docs

    @classmethod
def get_user_settings(cls) -> dict:
from ktem.llms.manager import llms
try:
reranking_llm = llms.get_default_name()
reranking_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)
reranking_llm = None
reranking_llm_choices = []
return {
"reranking_llm": {
"name": "LLM for relevant scoring",
"value": reranking_llm,
"component": "dropdown",
"choices": reranking_llm_choices,
"special_type": "llm",
},
"num_retrieval": {
"name": "Number of document chunks to retrieve",
"value": 10,
"component": "number",
},
"retrieval_mode": {
"name": "Retrieval mode",
"value": "hybrid",
"choices": ["vector", "text", "hybrid"],
"component": "dropdown",
},
"prioritize_table": {
"name": "Prioritize table",
"value": False,
"choices": [True, False],
"component": "checkbox",
},
"mmr": {
"name": "Use MMR",
"value": False,
"choices": [True, False],
"component": "checkbox",
},
"use_reranking": {
"name": "Use reranking",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
"use_llm_reranking": {
"name": "Use LLM relevant scoring",
"value": True,
"choices": [True, False],
"component": "checkbox",
},
}

    @classmethod
def get_pipeline(cls, user_settings, index_settings, selected):
"""Get retriever objects associated with the index
Args:
settings: the settings of the app
kwargs: other arguments
"""
use_llm_reranking = user_settings.get("use_llm_reranking", False)
retriever = cls(
get_extra_table=user_settings["prioritize_table"],
top_k=user_settings["num_retrieval"],
mmr=user_settings["mmr"],
embedding=embedding_models_manager[
index_settings.get(
"embedding", embedding_models_manager.get_default_name()
)
],
retrieval_mode=user_settings["retrieval_mode"],
llm_scorer=(LLMTrulensScoring() if use_llm_reranking else None),
rerankers=[
reranking_models_manager[
index_settings.get(
"reranking", reranking_models_manager.get_default_name()
)
]
],
)
if not user_settings["use_reranking"]:
retriever.rerankers = [] # type: ignore
for reranker in retriever.rerankers:
if isinstance(reranker, LLMReranking):
reranker.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
if retriever.llm_scorer:
retriever.llm_scorer.llm = llms.get(
user_settings["reranking_llm"], llms.get_default()
)
kwargs = {".doc_ids": selected}
retriever.set_run(kwargs, temp=False)
return retriever
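
# A minimal sketch (not from the source) of obtaining a retriever on the app
# side; `my_user_settings` is assumed to hold the keys from get_user_settings():
#
#     retriever = DocumentRetrievalPipeline.get_pipeline(
#         user_settings=my_user_settings,
#         index_settings={},
#         selected=["<file-id>"],
#     )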


class IndexPipeline(BaseComponent):
"""Index a single file"""
loader: BaseReader
splitter: BaseSplitter | None
chunk_batch_size: int = 200
Source = Param(help="The SQLAlchemy Source table")
Index = Param(help="The SQLAlchemy Index table")
VS = Param(help="The VectorStore")
DS = Param(help="The DocStore")
FSPath = Param(help="The file storage path")
user_id = Param(help="The user id")
collection_name: str = "default"
private: bool = False
run_embedding_in_thread: bool = False
embedding: BaseEmbeddings

    @Node.auto(depends_on=["Source", "Index", "embedding"])
def vector_indexing(self) -> VectorIndexing:
return VectorIndexing(
vector_store=self.VS, doc_store=self.DS, embedding=self.embedding
)

    def handle_docs(self, docs, file_id, file_name) -> Generator[Document, None, int]:
s_time = time.time()
text_docs = []
non_text_docs = []
thumbnail_docs = []
for doc in docs:
doc_type = doc.metadata.get("type", "text")
if doc_type == "text":
text_docs.append(doc)
elif doc_type == "thumbnail":
thumbnail_docs.append(doc)
else:
non_text_docs.append(doc)
print(f"Got {len(thumbnail_docs)} page thumbnails")
page_label_to_thumbnail = {
doc.metadata["page_label"]: doc.doc_id for doc in thumbnail_docs
}
if self.splitter:
all_chunks = self.splitter(text_docs)
else:
all_chunks = text_docs
        # add the thumbnail doc_id to the chunks (matched by page_label)
for chunk in all_chunks:
page_label = chunk.metadata.get("page_label", None)
if page_label and page_label in page_label_to_thumbnail:
chunk.metadata["thumbnail_doc_id"] = page_label_to_thumbnail[page_label]
to_index_chunks = all_chunks + non_text_docs + thumbnail_docs
        # add to doc store (batched; 4x the embedding batch size used below)
        chunks = []
        n_chunks = 0
        chunk_size = self.chunk_batch_size * 4
for start_idx in range(0, len(to_index_chunks), chunk_size):
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
self.handle_chunks_docstore(chunks, file_id)
n_chunks += len(chunks)
yield Document(
f" => [{file_name}] Processed {n_chunks} chunks",
channel="debug",
)

        def insert_chunks_to_vectorstore():
chunks = []
n_chunks = 0
chunk_size = self.chunk_batch_size
for start_idx in range(0, len(to_index_chunks), chunk_size):
chunks = to_index_chunks[start_idx : start_idx + chunk_size]
self.handle_chunks_vectorstore(chunks, file_id)
n_chunks += len(chunks)
if self.VS:
yield Document(
f" => [{file_name}] Created embedding for {n_chunks} chunks",
channel="debug",
)
# run vector indexing in thread if specified
if self.run_embedding_in_thread:
print("Running embedding in thread")
threading.Thread(
target=lambda: list(insert_chunks_to_vectorstore())
).start()
else:
yield from insert_chunks_to_vectorstore()
print("indexing step took", time.time() - s_time)
return n_chunks

    def handle_chunks_docstore(self, chunks, file_id):
        """Add the chunks to the doc store and record them in the Index table"""
        self.vector_indexing.add_to_docstore(chunks)
        # record in the index
with Session(engine) as session:
nodes = []
for chunk in chunks:
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="document",
)
)
session.add_all(nodes)
session.commit()

    def handle_chunks_vectorstore(self, chunks, file_id):
        """Run embedding for the chunks and add them to the vector store"""
        self.vector_indexing.add_to_vectorstore(chunks)
self.vector_indexing.write_chunk_to_file(chunks)
if self.VS:
# record in the index
with Session(engine) as session:
nodes = []
for chunk in chunks:
nodes.append(
self.Index(
source_id=file_id,
target_id=chunk.doc_id,
relation_type="vector",
)
)
session.add_all(nodes)
session.commit()

    def get_id_if_exists(self, file_path: str | Path) -> Optional[str]:
        """Check if the file is already indexed

        Args:
            file_path: the path to the file

        Returns:
            the file id if the file is indexed, otherwise None
        """
file_name = file_path.name if isinstance(file_path, Path) else file_path
if self.private:
cond: tuple = (
self.Source.name == file_name,
self.Source.user == self.user_id,
)
else:
cond = (self.Source.name == file_name,)
with Session(engine) as session:
stmt = select(self.Source).where(*cond)
item = session.execute(stmt).first()
if item:
return item[0].id
return None

    def store_url(self, url: str) -> str:
        """Store URL into the database and storage, return the file id

        Args:
            url: the URL

        Returns:
            the file id
        """
file_hash = sha256(url.encode()).hexdigest()
source = self.Source(
name=url,
path=file_hash,
size=0,
user=self.user_id, # type: ignore
)
with Session(engine) as session:
session.add(source)
session.commit()
file_id = source.id
return file_id

    def store_file(self, file_path: Path) -> str:
        """Store file into the database and storage, return the file id

        Args:
            file_path: the path to the file

        Returns:
            the file id
        """
with file_path.open("rb") as fi:
file_hash = sha256(fi.read()).hexdigest()
shutil.copy(file_path, self.FSPath / file_hash)
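        # the stored copy is content-addressed (named by its sha256), so files
        # with identical content map to the same blob on disk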
source = self.Source(
name=file_path.name,
path=file_hash,
size=file_path.stat().st_size,
user=self.user_id, # type: ignore
)
with Session(engine) as session:
session.add(source)
session.commit()
file_id = source.id
return file_id

    def finish(self, file_id: str, file_path: str | Path) -> str:
"""Finish the indexing"""
with Session(engine) as session:
stmt = select(self.Source).where(self.Source.id == file_id)
result = session.execute(stmt).first()
if not result:
return file_id
item = result[0]
# populate the number of tokens
doc_ids_stmt = select(self.Index.target_id).where(
self.Index.source_id == file_id,
self.Index.relation_type == "document",
)
doc_ids = [_[0] for _ in session.execute(doc_ids_stmt)]
token_func = self.get_token_func()
if doc_ids and token_func:
docs = self.DS.get(doc_ids)
item.note["tokens"] = sum([len(token_func(doc.text)) for doc in docs])
# populate the note
item.note["loader"] = self.get_from_path("loader").__class__.__name__
session.add(item)
session.commit()
return file_id

    def get_token_func(self):
"""Get the token function for calculating the number of tokens"""
return _default_token_func
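
    # e.g. (a sketch; `pipeline` stands for an IndexPipeline instance):
    #     token_func = pipeline.get_token_func()
    #     n_tokens = len(token_func("hello world"))  # 2 tokens under cl100k_base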

    def delete_file(self, file_id: str):
        """Delete a file from the db, including its chunks in docstore and vectorstore

        Args:
            file_id: the file id
        """
with Session(engine) as session:
session.execute(delete(self.Source).where(self.Source.id == file_id))
vs_ids, ds_ids = [], []
index = session.execute(
select(self.Index).where(self.Index.source_id == file_id)
).all()
for each in index:
if each[0].relation_type == "vector":
vs_ids.append(each[0].target_id)
elif each[0].relation_type == "document":
ds_ids.append(each[0].target_id)
session.delete(each[0])
session.commit()
if vs_ids and self.VS:
self.VS.delete(vs_ids)
if ds_ids:
self.DS.delete(ds_ids)

    def run(
self, file_path: str | Path, reindex: bool, **kwargs
) -> tuple[str, list[Document]]:
raise NotImplementedError

    def stream(
self, file_path: str | Path, reindex: bool, **kwargs
) -> Generator[Document, None, tuple[str, list[Document]]]:
# check if the file is already indexed
if isinstance(file_path, Path):
file_path = file_path.resolve()
file_id = self.get_id_if_exists(file_path)
if isinstance(file_path, Path):
if file_id is not None:
if not reindex:
raise ValueError(
f"File {file_path.name} already indexed. Please rerun with "
"reindex=True to force reindexing."
)
else:
# remove the existing records
yield Document(
f" => Removing old {file_path.name}", channel="debug"
)
self.delete_file(file_id)
file_id = self.store_file(file_path)
else:
# add record to db
file_id = self.store_file(file_path)
else:
if file_id is not None:
raise ValueError(f"URL {file_path} already indexed.")
else:
# add record to db
file_id = self.store_url(file_path)
# extract the file
if isinstance(file_path, Path):
extra_info = default_file_metadata_func(str(file_path))
file_name = file_path.name
else:
extra_info = {"file_name": file_path}
file_name = file_path
extra_info["file_id"] = file_id
extra_info["collection_name"] = self.collection_name
yield Document(f" => Converting {file_name} to text", channel="debug")
docs = self.loader.load_data(file_path, extra_info=extra_info)
yield Document(f" => Converted {file_name} to text", channel="debug")
yield from self.handle_docs(docs, file_id, file_name)
self.finish(file_id, file_path)
yield Document(f" => Finished indexing {file_name}", channel="debug")
return file_id, docs


class IndexDocumentPipeline(BaseFileIndexIndexing):
"""Index the file. Decide which pipeline based on the file type.
This method is essentially a factory to decide which indexing pipeline to use.
We can decide the pipeline programmatically, and/or automatically based on an LLM.
If we based on the LLM, essentially we will log the LLM thought process in a file,
and then during the indexing, we will read that file to decide which pipeline
to use, and then log the operation in that file. Overtime, the LLM can learn to
decide which pipeline should be used.
"""
reader_mode: str = Param("default", help="The reader mode")
embedding: BaseEmbeddings
run_embedding_in_thread: bool = False

    @Param.auto(depends_on="reader_mode")
def readers(self):
readers = deepcopy(KH_DEFAULT_FILE_EXTRACTORS)
print("reader_mode", self.reader_mode)
if self.reader_mode == "adobe":
readers[".pdf"] = adobe_reader
elif self.reader_mode == "azure-di":
readers[".pdf"] = azure_reader
dev_readers, _, _ = dev_settings()
readers.update(dev_readers)
return readers

    @classmethod
def get_user_settings(cls):
return {
"reader_mode": {
"name": "File loader",
"value": "default",
"choices": [
("Default (open-source)", "default"),
("Adobe API (figure+table extraction)", "adobe"),
(
"Azure AI Document Intelligence (figure+table extraction)",
"azure-di",
),
],
"component": "dropdown",
},
}

    @classmethod
def get_pipeline(cls, user_settings, index_settings) -> BaseFileIndexIndexing:
use_quick_index_mode = user_settings.get("quick_index_mode", False)
print("use_quick_index_mode", use_quick_index_mode)
obj = cls(
embedding=embedding_models_manager[
index_settings.get(
"embedding", embedding_models_manager.get_default_name()
)
],
run_embedding_in_thread=use_quick_index_mode,
reader_mode=user_settings.get("reader_mode", "default"),
)
return obj

    def is_url(self, file_path: str | Path) -> bool:
return isinstance(file_path, str) and (
file_path.startswith("http://") or file_path.startswith("https://")
)

    def route(self, file_path: str | Path) -> IndexPipeline:
        """Decide the pipeline based on the file type

        Can subclass this method for a more elaborate pipeline routing strategy.
        """
_, chunk_size, chunk_overlap = dev_settings()
# check if file_path is a URL
if self.is_url(file_path):
reader = web_reader
else:
assert isinstance(file_path, Path)
ext = file_path.suffix.lower()
reader = self.readers.get(ext, unstructured)
if reader is None:
raise NotImplementedError(
f"No supported pipeline to index {file_path.name}. Please specify "
"the suitable pipeline for this file type in the settings."
)
print("Using reader", reader)
pipeline: IndexPipeline = IndexPipeline(
loader=reader,
splitter=TokenSplitter(
chunk_size=chunk_size or 1024,
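                # `is not None` below (rather than truthiness) so an explicit
                # chunk_overlap of 0 is honored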
chunk_overlap=chunk_overlap if chunk_overlap is not None else 256,
separator="\n\n",
backup_separators=["\n", ".", "\u200B"],
),
run_embedding_in_thread=self.run_embedding_in_thread,
Source=self.Source,
Index=self.Index,
VS=self.VS,
DS=self.DS,
FSPath=self.FSPath,
user_id=self.user_id,
private=self.private,
embedding=self.embedding,
)
return pipeline
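
    # Sketch of a custom routing strategy (an assumption, not in the source):
    #
    #     class MyIndexDocumentPipeline(IndexDocumentPipeline):
    #         def route(self, file_path: str | Path) -> IndexPipeline:
    #             pipeline = super().route(file_path)
    #             if isinstance(file_path, Path) and file_path.suffix == ".md":
    #                 pipeline.splitter = None  # keep small docs unsplit
    #             return pipeline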

    def run(
self, file_paths: str | Path | list[str | Path], *args, **kwargs
) -> tuple[list[str | None], list[str | None]]:
raise NotImplementedError

    def stream(
self, file_paths: str | Path | list[str | Path], reindex: bool = False, **kwargs
) -> Generator[
Document, None, tuple[list[str | None], list[str | None], list[Document]]
]:
"""Return a list of indexed file ids, and a list of errors"""
if not isinstance(file_paths, list):
file_paths = [file_paths]
file_ids: list[str | None] = []
errors: list[str | None] = []
all_docs = []
n_files = len(file_paths)
for idx, file_path in enumerate(file_paths):
if self.is_url(file_path):
file_name = file_path
else:
file_path = Path(file_path)
file_name = file_path.name
yield Document(
content=f"Indexing [{idx + 1}/{n_files}]: {file_name}",
channel="debug",
)
try:
pipeline = self.route(file_path)
file_id, docs = yield from pipeline.stream(
file_path, reindex=reindex, **kwargs
)
all_docs.extend(docs)
file_ids.append(file_id)
errors.append(None)
yield Document(
content={
"file_path": file_path,
"file_name": file_name,
"status": "success",
},
channel="index",
)
except Exception as e:
logger.exception(e)
file_ids.append(None)
errors.append(str(e))
yield Document(
content={
"file_path": file_path,
"file_name": file_name,
"status": "failed",
"message": str(e),
},
channel="index",
)
return file_ids, errors, all_docs
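

# End-to-end usage sketch (illustrative, not part of the module): drive the
# indexing stream and collect its return value, assuming `pipeline` was built
# via IndexDocumentPipeline.get_pipeline(...) with its index resources attached:
#
#     gen = pipeline.stream(["report.pdf"], reindex=False)
#     try:
#         while True:
#             print(next(gen).content)      # progress Documents
#     except StopIteration as e:
#         file_ids, errors, all_docs = e.value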