feat: merge develop (#123)

* Support hybrid vector retrieval
* Enable figure and table reading in Azure DI
* Retrieve with multi-modal
* Fix mixed-up table
* Add txt loader
* Add Anthropic Chat
* Raise an error when retrieving help file
* Allow the same filename for different people if private is True
* Allow declaring extra LLM vendors
* Show chunks on the File page
* Allow elasticsearch to get more docs
* Fix Cohere response (#86)
  * Fix Cohere response
  * Remove Adobe pdfservice from dependencies: kotaemon no longer relies on pdfservice for its core functionality, and pdfservice uses a very outdated dependency that causes conflicts.
  Co-authored-by: trducng <trungduc1992@gmail.com>
* Add confidence score (#87)
  * Save question answering data as a log file
  * Save the original information besides the rewritten info
  * Export Cohere relevance score as confidence score
  * Fix style check
* Upgrade the confidence score appearance (#90)
  * Highlight the relevance score
  * Round the relevance score; get the key from config instead of env
  * Cohere returns all scores
  * Display relevance score for images
* Remove columns and rows in the Excel loader which contain all NaN (#91)
  * Remove columns and rows which contain all NaN
  * Back to multiple joiner options
  * Fix style
  Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
  Co-authored-by: trducng <trungduc1992@gmail.com>
* Track retriever state
* Bump llama-index to version 0.10
* feat/save-azuredi-mhtml-to-markdown (#93)
  * feat/save-azuredi-mhtml-to-markdown
  * fix: replace os.path with pathlib; change theflow.settings
  * refactor: base on pre-commit
  * chore: move the function that saves markdown content above removed_spans
  Co-authored-by: jacky0218 <jacky0218@github.com>
* fix: losing first chunk (#94)
  * fix: losing first chunk
  * fix: update the method of preventing lost chunks
  Co-authored-by: jacky0218 <jacky0218@github.com>
* fix: add the base64 image in markdown (#95)
* feat: more chunk info on UI
* fix: error when reindexing files
* refactor: allow a more informative exception trace when using gpt4v
* feat: add excel reader that treats each worksheet as a document
* Persist loader information when indexing files
* feat: allow hiding unneeded setting panels
* feat: allow a specific timezone when creating a conversation
* feat: add more confidence scores (#96)
  * Allow a list of rerankers
  * Export the LLM reranking score instead of filtering with a boolean
  * Get logprobs from LLMs
  * Rename cohere reranking score
  * Call 2 rerankers at once
  * Run QA pipeline for each chunk to get qa_score
  * Display more relevance scores
  * Define another LLMScoring instead of editing the original one
  * Export logprobs instead of probs
  * Call LLMScoring
  * Get qa_score only in the final answer
* feat: replace text length with token count in file list
* ui: show index name instead of id in the settings
* feat(ai): restrict the vision temperature
* fix(ui): remove the misleading message about non-retrieved evidences
* feat(ui): show the reasoning name and description in the reasoning settings page
* feat(ui): show version on the main window
* feat(ui): show default LLM name in the settings page
* fix(conf): append the result of doc in llm_scoring (#97)
* fix: constrain the maximum number of images
* feat(ui): allow filtering files by name on the file list page
* Fix exceeding-token-length error for OpenAI embeddings by chunking then averaging (#99)
  * Average embeddings in case the text exceeds max size
  * Add docstring
* fix: allow empty string when calling embedding
* fix: update trulens LLM ranking score for retrieval confidence, improve citation (#98)
  * Round when displaying, not by default
  * Add LLMTrulens reranking model
  * Use LLMTrulensScoring in pipeline
  * fix: update UI display for trulens score
  Co-authored-by: taprosoft <tadashi@cinnamon.is>
* feat: add question decomposition & few-shot rewrite pipeline (#89)
  * Create few-shot query rewriting; run and display the result in info_panel
  * Fix style check
  * Put the functions into separate modules
  * Add zero-shot question decomposition
  * Fix few-shot rewriting
  * Add default few-shot examples
  * Fix decompose question
  * Fix importing rewriting pipelines
  * fix: update decompose logic in fullQA pipeline
  Co-authored-by: taprosoft <tadashi@cinnamon.is>
* fix: add utf-8 encoding when saving temporary markdown in vectorIndex (#101)
* fix: improve retrieval pipeline and relevance score display (#102)
  * fix: improve retrieval pipeline by extending first-round top_k with a multiplier
  * fix: minor fix
* feat: improve UI default settings and add a quick-switch option for pipeline
* fix: improve agent logic (#103)
  * fix: improve agent progress display
  * fix: update retrieval logic
  * fix: UI display
  * fix: less verbose debug log
  * feat: add warning message for low confidence
  * fix: LLM scoring enabled by default
  * fix: minor logic updates
  * fix: hotfix image citation
* feat: update docx loader to handle merged table cells + handle zip file upload (#104)
  * feat: update docx loader to handle merged table cells
  * feat: handle zip file
  * refactor: pre-commit
  * fix: escape text in download UI
* feat: optimize vector store query db (#105)
  * feat: optimize vector store query db
  * feat: add file_id to chroma metadatas
  * feat: remove unnecessary logs and update migrate script
  * feat: iterate through file index
  * fix: remove unused code
  Co-authored-by: taprosoft <tadashi@cinnamon.is>
* fix: add openai embedding exponential back-off
* fix: update import download_loader
* refactor: codespell
* fix: update some default settings
* fix: update installation instructions
* fix: default chunk length in simple QA
* feat: add share conversation feature and enable retrieval history (#108)
  * feat: add share conversation feature and enable retrieval history
  * fix: update share conversation UI
  Co-authored-by: taprosoft <tadashi@cinnamon.is>
* fix: allow exponential backoff for failed OCR calls (#109)
* fix: update default prompt when no retrieval is used
* fix: create embeddings for long image chunks
* fix: add exception handling for additional table retriever
* fix: clean conversation & file selection UI
* fix: elastic search with empty doc_ids
* feat: add thumbnail PDF reader for quick multimodal QA
* feat: add thumbnail handling logic in indexing
* fix: UI text update
* fix: PDF thumb loader page number logic
* feat: add quick indexing pipeline and update UI
* feat: add conv name suggestion
* fix: minor UI change
* feat: citation in thread
* fix: add conv name suggestion in regen
* chore: add assets for usage doc
* chore: update usage doc
* feat: pdf viewer (#110)
  * feat: update pdfviewer
  * feat: update missing files
  * fix: update rendering logic of info panel
  * fix: improve thumbnail retrieval logic
  * fix: update PDF evidence rendering logic
  * fix: remove pdfjs built dist
  * fix: reduce thumbnail evidence count
  * chore: update gitignore
  * fix: add js event on chat msg select
  * fix: update css for viewer
  * fix: add env var for PDFJS prebuilt
  * fix: move language setting to reasoning utils
  Co-authored-by: phv2312 <kat87yb@gmail.com>
  Co-authored-by: trducng <trungduc1992@gmail.com>
* feat: graph rag (#116)
  * fix: reload server when adding/deleting an index
  * fix: rework indexing pipeline to be able to disable vectorstore and splitter if needed
  * feat: add graphRAG index with plot view
  * fix: update requirements for graphRAG and lighten unnecessary packages
* feat: add knowledge network index (#118)
  * feat: add Knowledge Network index
  * fix: update reader mode setting for knet
  * fix: update init knet
  * fix: update collection name to index pipeline
  * fix: missing req
  Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>
* fix: update info panel return for graphrag
* fix: retriever setting graphrag
* feat: local llm settings (#122)
  * feat: expose context length as a reasoning setting to better fit local models
  * fix: update context length setting for agents
  * fix: rework threadpool llm call
  * fix: improve indexing logic
  * fix: improve UI
  * feat: add lancedb
  * fix: improve lancedb logic
  * feat: add lancedb vectorstore
  * fix: lighten requirements
  * fix: improve lanceDB vectorstore
  * fix: improve UI
  * fix: openai retry
  * fix: update reqs
  * fix: update launch command
  * feat: update Dockerfile
  * feat: add plot history
  * fix: update default config
  * fix: remove verbose print
  * fix: update default setting
  * fix: update gradio plot return
  * fix: default gradio tmp
  * fix: improve lancedb docstore
  * fix: question decompose pipeline
  * feat: add multimodal reader in UI
  * fix: update docs
  * fix: update default settings & docker build
  * fix: update app startup
  * chore: update documentation
  * chore: update README
  Co-authored-by: trducng <trungduc1992@gmail.com>
* chore: update README

Co-authored-by: trducng <trungduc1992@gmail.com>
Co-authored-by: cin-ace <ace@cinnamon.is>
Co-authored-by: Linh Nguyen <70562198+linhnguyen-cinnamon@users.noreply.github.com>
Co-authored-by: linhnguyen-cinnamon <cinmc0019@CINMC0019-LinhNguyen.local>
Co-authored-by: cin-jacky <101088014+jacky0218@users.noreply.github.com>
Co-authored-by: jacky0218 <jacky0218@github.com>
Co-authored-by: kan_cin <kan@cinnamon.is>
Co-authored-by: phv2312 <kat87yb@gmail.com>
Co-authored-by: jeff52415 <jeff.yang@cinnamon.is>
commit 2570e11501 (parent 86d60e1649)
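Two of the techniques named above recur throughout the diff that follows. The OpenAI embedding fix (#99) handles inputs longer than the embedding model's context window by token-chunking the text, embedding each chunk, then taking a length-weighted average of the chunk vectors and re-normalizing. A minimal sketch of the idea, assuming a caller-supplied `embed_batch` callable (a hypothetical stand-in for the OpenAI embeddings call, not part of the codebase):

```python
import numpy as np
import tiktoken


def embed_long_text(text: str, embed_batch, context_length: int) -> list[float]:
    """Embed text that may exceed the embedding model's context window.

    Tokenize, split into context-window-sized chunks, embed each chunk,
    then average the chunk vectors weighted by chunk length and re-normalize.
    """
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text) or enc.encode(" ")  # avoid empty input
    chunks = [
        tokens[i : i + context_length] for i in range(0, len(tokens), context_length)
    ]
    vectors = embed_batch(chunks)  # hypothetical: one embedding per token chunk
    weights = [len(chunk) for chunk in chunks]  # a short tail chunk counts less
    avg = np.average(vectors, axis=0, weights=weights)
    return (avg / np.linalg.norm(avg)).tolist()
```

The weighted average keeps the result close to what a single embedding of the full text would give, at the cost of losing cross-chunk context; the real implementation in `BaseOpenAIEmbeddings.invoke` below batches all chunks of all documents into one API call.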
@@ -39,16 +39,11 @@ class ReactAgent(BaseAgent):
     )
     max_iterations: int = 5
     strict_decode: bool = False
-    trim_func: TokenSplitter = TokenSplitter.withx(
-        chunk_size=800,
-        chunk_overlap=0,
-        separator=" ",
-        tokenizer=partial(
-            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
-            allowed_special=set(),
-            disallowed_special="all",
-        ),
+    max_context_length: int = Param(
+        default=3000,
+        help="Max context length for each tool output.",
     )
+    trim_func: TokenSplitter | None = None

     def _compose_plugin_description(self) -> str:
         """
@@ -149,14 +144,28 @@ class ReactAgent(BaseAgent):
             function_map[plugin.name] = plugin
         return function_map

-    def _trim(self, text: str) -> str:
+    def _trim(self, text: str | Document) -> str:
         """
         Trim the text to the maximum token length.
         """
+        evidence_trim_func = (
+            self.trim_func
+            if self.trim_func
+            else TokenSplitter(
+                chunk_size=self.max_context_length,
+                chunk_overlap=0,
+                separator=" ",
+                tokenizer=partial(
+                    tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+                    allowed_special=set(),
+                    disallowed_special="all",
+                ),
+            )
+        )
         if isinstance(text, str):
-            texts = self.trim_func([Document(text=text)])
+            texts = evidence_trim_func([Document(text=text)])
         elif isinstance(text, Document):
-            texts = self.trim_func([text])
+            texts = evidence_trim_func([text])
         else:
             raise ValueError("Invalid text type to trim")
         trim_text = texts[0].text
@@ -39,16 +39,11 @@ class RewooAgent(BaseAgent):
     examples: dict[str, str | list[str]] = Param(
         default_callback=lambda _: {}, help="Examples to be used in the agent."
     )
-    trim_func: TokenSplitter = TokenSplitter.withx(
-        chunk_size=3000,
-        chunk_overlap=0,
-        separator=" ",
-        tokenizer=partial(
-            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
-            allowed_special=set(),
-            disallowed_special="all",
-        ),
+    max_context_length: int = Param(
+        default=3000,
+        help="Max context length for each tool output.",
     )
+    trim_func: TokenSplitter | None = None

     @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
     def planner(self):
@@ -248,8 +243,22 @@ class RewooAgent(BaseAgent):
         return p

     def _trim_evidence(self, evidence: str):
+        evidence_trim_func = (
+            self.trim_func
+            if self.trim_func
+            else TokenSplitter(
+                chunk_size=self.max_context_length,
+                chunk_overlap=0,
+                separator=" ",
+                tokenizer=partial(
+                    tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+                    allowed_special=set(),
+                    disallowed_special="all",
+                ),
+            )
+        )
         if evidence:
-            texts = self.trim_func([Document(text=evidence)])
+            texts = evidence_trim_func([Document(text=evidence)])
             evidence = texts[0].text
             logging.info(f"len (trimmed): {len(evidence)}")
         return evidence
@@ -317,6 +326,14 @@ class RewooAgent(BaseAgent):
         )

         print("Planner output:", planner_text_output)
+        # output planner to info panel
+        yield AgentOutput(
+            text="",
+            agent_type=self.agent_type,
+            status="thinking",
+            intermediate_steps=[{"planner_log": planner_text_output}],
+        )
+
         # Work
         worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
             planner_evidences, evidence_level
@@ -326,7 +343,9 @@ class RewooAgent(BaseAgent):
             worker_log += f"{plan}: {plans[plan]}\n"
+            current_progress = f"{plan}: {plans[plan]}\n"
             for e in plan_to_es[plan]:
                 worker_log += f"#Action: {planner_evidences.get(e, None)}\n"
                 worker_log += f"{e}: {worker_evidences[e]}\n"
+                current_progress += f"#Action: {planner_evidences.get(e, None)}\n"
+                current_progress += f"{e}: {worker_evidences[e]}\n"

             yield AgentOutput(
@@ -1,7 +1,7 @@
 from typing import AnyStr, Optional, Type
 from urllib.error import HTTPError

-from langchain.utilities import SerpAPIWrapper
+from langchain_community.utilities import SerpAPIWrapper
 from pydantic import BaseModel, Field

 from .base import BaseTool
@@ -22,12 +22,16 @@ class LLMTool(BaseTool):
     )
     llm: BaseLLM
     args_schema: Optional[Type[BaseModel]] = LLMArgs
+    dummy_mode: bool = True

     def _run_tool(self, query: AnyStr) -> str:
         output = None
         try:
-            response = self.llm(query)
+            if not self.dummy_mode:
+                response = self.llm(query)
+            else:
+                response = None
         except ValueError:
             raise ToolException("LLM Tool call failed")
-        output = response.text
+        output = response.text if response else "<->"
         return output
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar
 from langchain.schema.messages import AIMessage as LCAIMessage
 from langchain.schema.messages import HumanMessage as LCHumanMessage
 from langchain.schema.messages import SystemMessage as LCSystemMessage
-from llama_index.bridge.pydantic import Field
-from llama_index.schema import Document as BaseDocument
+from llama_index.core.bridge.pydantic import Field
+from llama_index.core.schema import Document as BaseDocument

 if TYPE_CHECKING:
     from haystack.schema import Document as HaystackDocument
@@ -38,7 +38,7 @@ class Document(BaseDocument):

     content: Any = None
     source: Optional[str] = None
-    channel: Optional[Literal["chat", "info", "index", "debug"]] = None
+    channel: Optional[Literal["chat", "info", "index", "debug", "plot"]] = None

     def __init__(self, content: Optional[Any] = None, *args, **kwargs):
         if content is None:
@@ -140,6 +140,7 @@ class LLMInterface(AIMessage):
     total_cost: float = 0
     logits: list[list[float]] = Field(default_factory=list)
     messages: list[AIMessage] = Field(default_factory=list)
+    logprobs: list[float] = []


 class ExtractorOutput(Document):
@@ -133,9 +133,7 @@ def construct_chat_ui(
                 label="Output file", show_label=True, height=100
             )
             export_btn = gr.Button("Export")
-            export_btn.click(
-                func_export_to_excel, inputs=None, outputs=exported_file
-            )
+            export_btn.click(func_export_to_excel, inputs=[], outputs=exported_file)

     with gr.Row():
         with gr.Column():
@@ -91,7 +91,7 @@ def construct_pipeline_ui(
             save_btn.click(func_save, inputs=params, outputs=history_dataframe)
             load_params_btn = gr.Button("Reload params")
             load_params_btn.click(
-                func_load_params, inputs=None, outputs=history_dataframe
+                func_load_params, inputs=[], outputs=history_dataframe
             )
             history_dataframe.render()
             history_dataframe.select(
@@ -103,7 +103,7 @@ def construct_pipeline_ui(
             export_btn = gr.Button(
                 "Export (Result will be in Exported file next to Output)"
             )
-            export_btn.click(func_export, inputs=None, outputs=exported_file)
+            export_btn.click(func_export, inputs=[], outputs=exported_file)
     with gr.Row():
         with gr.Column():
             if params:
@@ -1,5 +1,15 @@
+from itertools import islice
+from typing import Optional
+
+import numpy as np
 import openai
+import tiktoken
+from tenacity import (
+    retry,
+    retry_if_not_exception_type,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 from theflow.utils.modules import import_dotted_string

 from kotaemon.base import Param
@@ -7,6 +17,24 @@ from kotaemon.base import Param
 from .base import BaseEmbeddings, Document, DocumentWithEmbedding


+def split_text_by_chunk_size(text: str, chunk_size: int) -> list[list[int]]:
+    """Split the text into chunks of a given size
+
+    Args:
+        text: text to split
+        chunk_size: size of each chunk
+
+    Returns:
+        list of chunks (as tokens)
+    """
+    encoding = tiktoken.get_encoding("cl100k_base")
+    tokens = iter(encoding.encode(text))
+    result = []
+    while chunk := list(islice(tokens, chunk_size)):
+        result.append(chunk)
+    return result
+
+
 class BaseOpenAIEmbeddings(BaseEmbeddings):
     """Base interface for OpenAI embedding model, using the openai library.

@@ -32,6 +60,9 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
             "Only supported in `text-embedding-3` and later models."
         ),
     )
+    context_length: Optional[int] = Param(
+        None, help="The maximum context length of the embedding model"
+    )

     @Param.auto(depends_on=["max_retries"])
     def max_retries_(self):
@@ -56,16 +87,42 @@ class BaseOpenAIEmbeddings(BaseEmbeddings):
     def invoke(
         self, text: str | list[str] | Document | list[Document], *args, **kwargs
     ) -> list[DocumentWithEmbedding]:
-        input_ = self.prepare_input(text)
+        input_doc = self.prepare_input(text)
         client = self.prepare_client(async_version=False)
-        resp = self.openai_response(
-            client, input=[_.text if _.text else " " for _ in input_], **kwargs
-        ).dict()
-        output_ = sorted(resp["data"], key=lambda x: x["index"])
-        return [
-            DocumentWithEmbedding(embedding=o["embedding"], content=i)
-            for i, o in zip(input_, output_)
-        ]
+
+        input_: list[str | list[int]] = []
+        splitted_indices = {}
+        for idx, text in enumerate(input_doc):
+            if self.context_length:
+                chunks = split_text_by_chunk_size(text.text or " ", self.context_length)
+                splitted_indices[idx] = (len(input_), len(input_) + len(chunks))
+                input_.extend(chunks)
+            else:
+                splitted_indices[idx] = (len(input_), len(input_) + 1)
+                input_.append(text.text)
+
+        resp = self.openai_response(client, input=input_, **kwargs).dict()
+        output_ = list(sorted(resp["data"], key=lambda x: x["index"]))
+
+        output = []
+        for idx, doc in enumerate(input_doc):
+            embs = output_[splitted_indices[idx][0] : splitted_indices[idx][1]]
+            if len(embs) == 1:
+                output.append(
+                    DocumentWithEmbedding(embedding=embs[0]["embedding"], content=doc)
+                )
+                continue
+
+            chunk_lens = [
+                len(_)
+                for _ in input_[splitted_indices[idx][0] : splitted_indices[idx][1]]
+            ]
+            vs: list[list[float]] = [_["embedding"] for _ in embs]
+            emb = np.average(vs, axis=0, weights=chunk_lens)
+            emb = emb / np.linalg.norm(emb)
+            output.append(DocumentWithEmbedding(embedding=emb.tolist(), content=doc))
+
+        return output

     async def ainvoke(
         self, text: str | list[str] | Document | list[Document], *args, **kwargs
@@ -118,6 +175,13 @@ class OpenAIEmbeddings(BaseOpenAIEmbeddings):

         return OpenAI(**params)

+    @retry(
+        retry=retry_if_not_exception_type(
+            (openai.NotFoundError, openai.BadRequestError)
+        ),
+        wait=wait_random_exponential(min=1, max=40),
+        stop=stop_after_attempt(6),
+    )
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
         params: dict = {
@@ -174,6 +238,13 @@ class AzureOpenAIEmbeddings(BaseOpenAIEmbeddings):

         return AzureOpenAI(**params)

+    @retry(
+        retry=retry_if_not_exception_type(
+            (openai.NotFoundError, openai.BadRequestError)
+        ),
+        wait=wait_random_exponential(min=1, max=40),
+        stop=stop_after_attempt(6),
+    )
     def openai_response(self, client, **kwargs):
         """Get the openai response"""
         params: dict = {
@@ -3,7 +3,7 @@ from __future__ import annotations
 from abc import abstractmethod
 from typing import Any, Type

-from llama_index.node_parser.interface import NodeParser
+from llama_index.core.node_parser.interface import NodeParser

 from kotaemon.base import BaseComponent, Document, RetrievedDocument

@@ -32,7 +32,7 @@ class LlamaIndexDocTransformerMixin:
     Example:
         class TokenSplitter(LlamaIndexMixin, BaseSplitter):
             def _get_li_class(self):
-                from llama_index.text_splitter import TokenTextSplitter
+                from llama_index.core.text_splitter import TokenTextSplitter
                 return TokenTextSplitter

     To use this mixin, please:
@@ -15,7 +15,7 @@ class TitleExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
         super().__init__(llm=llm, nodes=nodes, **params)

     def _get_li_class(self):
-        from llama_index.extractors import TitleExtractor
+        from llama_index.core.extractors import TitleExtractor

         return TitleExtractor

@@ -30,6 +30,6 @@ class SummaryExtractor(LlamaIndexDocTransformerMixin, BaseDocParser):
         super().__init__(llm=llm, summaries=summaries, **params)

     def _get_li_class(self):
-        from llama_index.extractors import SummaryExtractor
+        from llama_index.core.extractors import SummaryExtractor

         return SummaryExtractor
@@ -1,27 +1,42 @@
 from pathlib import Path
 from typing import Type

-from llama_index.readers import PDFReader
-from llama_index.readers.base import BaseReader
+from decouple import config
+from llama_index.core.readers.base import BaseReader
+from theflow.settings import settings as flowsettings

 from kotaemon.base import BaseComponent, Document, Param
 from kotaemon.indices.extractors import BaseDocParser
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 from kotaemon.loaders import (
     AdobeReader,
     AzureAIDocumentIntelligenceLoader,
     DirectoryReader,
     HtmlReader,
     MathpixPDFReader,
     MhtmlReader,
     OCRReader,
     PandasExcelReader,
+    PDFThumbnailReader,
     UnstructuredReader,
 )

 unstructured = UnstructuredReader()
+adobe_reader = AdobeReader()
+azure_reader = AzureAIDocumentIntelligenceLoader(
+    endpoint=str(config("AZURE_DI_ENDPOINT", default="")),
+    credential=str(config("AZURE_DI_CREDENTIAL", default="")),
+    cache_dir=getattr(flowsettings, "KH_MARKDOWN_OUTPUT_DIR", None),
+)
+adobe_reader.vlm_endpoint = azure_reader.vlm_endpoint = getattr(
+    flowsettings, "KH_VLM_ENDPOINT", ""
+)


 KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
     ".xlsx": PandasExcelReader(),
     ".docx": unstructured,
     ".pptx": unstructured,
     ".xls": unstructured,
     ".doc": unstructured,
     ".html": HtmlReader(),
@@ -31,7 +46,7 @@ KH_DEFAULT_FILE_EXTRACTORS: dict[str, BaseReader] = {
     ".jpg": unstructured,
     ".tiff": unstructured,
     ".tif": unstructured,
-    ".pdf": PDFReader(),
+    ".pdf": PDFThumbnailReader(),
 }

@@ -103,7 +103,9 @@ class CitationPipeline(BaseComponent):
         print("CitationPipeline: invoking LLM")
         llm_output = self.get_from_path("llm").invoke(messages, **llm_kwargs)
         print("CitationPipeline: finish invoking LLM")
-        if not llm_output.messages:
+        if not llm_output.messages or not llm_output.additional_kwargs.get(
+            "tool_calls"
+        ):
             return None
         function_output = llm_output.additional_kwargs["tool_calls"][0]["function"][
             "arguments"
@@ -1,5 +1,13 @@
 from .base import BaseReranking
 from .cohere import CohereReranking
 from .llm import LLMReranking
+from .llm_scoring import LLMScoring
+from .llm_trulens import LLMTrulensScoring

-__all__ = ["CohereReranking", "LLMReranking", "BaseReranking"]
+__all__ = [
+    "CohereReranking",
+    "LLMReranking",
+    "LLMScoring",
+    "BaseReranking",
+    "LLMTrulensScoring",
+]
@@ -1,6 +1,6 @@
 from __future__ import annotations

-import os
+from decouple import config

 from kotaemon.base import Document

@@ -9,8 +9,7 @@ from .base import BaseReranking

 class CohereReranking(BaseReranking):
     model_name: str = "rerank-multilingual-v2.0"
-    cohere_api_key: str = os.environ.get("COHERE_API_KEY", "")
-    top_k: int = 1
+    cohere_api_key: str = config("COHERE_API_KEY", "")

     def run(self, documents: list[Document], query: str) -> list[Document]:
         """Use Cohere Reranker model to re-order documents
@@ -22,6 +21,10 @@ class CohereReranking(BaseReranking):
                 "Please install Cohere " "`pip install cohere` to use Cohere Reranking"
             )

+        if not self.cohere_api_key:
+            print("Cohere API key not found. Skipping reranking.")
+            return documents
+
         cohere_client = cohere.Client(self.cohere_api_key)
         compressed_docs: list[Document] = []

@@ -29,12 +32,13 @@ class CohereReranking(BaseReranking):
             return compressed_docs

         _docs = [d.content for d in documents]
-        results = cohere_client.rerank(
-            model=self.model_name, query=query, documents=_docs, top_n=self.top_k
+        response = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs
         )
-        for r in results:
+        print("Cohere score", [r.relevance_score for r in response.results])
+        for r in response.results:
             doc = documents[r.index]
-            doc.metadata["relevance_score"] = r.relevance_score
+            doc.metadata["cohere_reranking_score"] = r.relevance_score
             compressed_docs.append(doc)

         return compressed_docs
libs/kotaemon/kotaemon/indices/rankings/llm_scoring.py (new file, 54 lines)
@@ -0,0 +1,54 @@
+from __future__ import annotations
+
+from concurrent.futures import ThreadPoolExecutor
+
+import numpy as np
+from langchain.output_parsers.boolean import BooleanOutputParser
+
+from kotaemon.base import Document
+
+from .llm import LLMReranking
+
+
+class LLMScoring(LLMReranking):
+    def run(
+        self,
+        documents: list[Document],
+        query: str,
+    ) -> list[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs: list[Document] = []
+        output_parser = BooleanOutputParser()
+
+        if self.concurrent:
+            with ThreadPoolExecutor() as executor:
+                futures = []
+                for doc in documents:
+                    _prompt = self.prompt_template.populate(
+                        question=query, context=doc.get_content()
+                    )
+                    futures.append(executor.submit(lambda: self.llm(_prompt)))
+
+                results = [future.result() for future in futures]
+        else:
+            results = []
+            for doc in documents:
+                _prompt = self.prompt_template.populate(
+                    question=query, context=doc.get_content()
+                )
+                results.append(self.llm(_prompt))
+
+        for result, doc in zip(results, documents):
+            score = np.exp(np.average(result.logprobs))
+            include_doc = output_parser.parse(result.text)
+            if include_doc:
+                doc.metadata["llm_reranking_score"] = score
+            else:
+                doc.metadata["llm_reranking_score"] = 1 - score
+            filtered_docs.append(doc)
+
+        # prevent returning empty result
+        if len(filtered_docs) == 0:
+            filtered_docs = documents[: self.top_k]
+
+        return filtered_docs
libs/kotaemon/kotaemon/indices/rankings/llm_trulens.py (new file, 182 lines)
@@ -0,0 +1,182 @@
+from __future__ import annotations
+
+import re
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+
+import tiktoken
+
+from kotaemon.base import Document, HumanMessage, SystemMessage
+from kotaemon.indices.splitters import TokenSplitter
+from kotaemon.llms import BaseLLM, PromptTemplate
+
+from .llm import LLMReranking
+
+SYSTEM_PROMPT_TEMPLATE = PromptTemplate(
+    """You are a RELEVANCE grader; providing the relevance of the given CONTEXT to the given QUESTION.
+Respond only as a number from 0 to 10 where 0 is the least relevant and 10 is the most relevant.
+
+A few additional scoring guidelines:
+
+- Long CONTEXTS should score equally well as short CONTEXTS.
+
+- RELEVANCE score should increase as the CONTEXTS provides more RELEVANT context to the QUESTION.
+
+- RELEVANCE score should increase as the CONTEXTS provides RELEVANT context to more parts of the QUESTION.
+
+- CONTEXT that is RELEVANT to some of the QUESTION should score of 2, 3 or 4. Higher score indicates more RELEVANCE.
+
+- CONTEXT that is RELEVANT to most of the QUESTION should get a score of 5, 6, 7 or 8. Higher score indicates more RELEVANCE.
+
+- CONTEXT that is RELEVANT to the entire QUESTION should get a score of 9 or 10. Higher score indicates more RELEVANCE.
+
+- CONTEXT must be relevant and helpful for answering the entire QUESTION to get a score of 10.
+
+- Never elaborate."""  # noqa: E501
+)
+
+USER_PROMPT_TEMPLATE = PromptTemplate(
+    """QUESTION: {question}
+
+CONTEXT: {context}
+
+RELEVANCE: """
+)  # noqa
+
+PATTERN_INTEGER: re.Pattern = re.compile(r"([+-]?[1-9][0-9]*|0)")
+"""Regex that matches integers."""
+
+MAX_CONTEXT_LEN = 7500
+
+
+def validate_rating(rating) -> int:
+    """Validate a rating is between 0 and 10."""
+
+    if not 0 <= rating <= 10:
+        raise ValueError("Rating must be between 0 and 10")
+
+    return rating
+
+
+def re_0_10_rating(s: str) -> int:
+    """Extract a 0-10 rating from a string.
+
+    If the string does not match an integer or matches an integer outside the
+    0-10 range, raises an error instead. If multiple numbers are found within
+    the expected 0-10 range, the smallest is returned.
+
+    Args:
+        s: String to extract rating from.
+
+    Returns:
+        int: Extracted rating.
+
+    Raises:
+        ParseError: If no integers between 0 and 10 are found in the string.
+    """
+
+    matches = PATTERN_INTEGER.findall(s)
+    if not matches:
+        raise AssertionError
+
+    vals = set()
+    for match in matches:
+        try:
+            vals.add(validate_rating(int(match)))
+        except ValueError:
+            pass
+
+    if not vals:
+        raise AssertionError
+
+    # Min to handle cases like "The rating is 8 out of 10."
+    return min(vals)
+
+
+class LLMTrulensScoring(LLMReranking):
+    llm: BaseLLM
+    system_prompt_template: PromptTemplate = SYSTEM_PROMPT_TEMPLATE
+    user_prompt_template: PromptTemplate = USER_PROMPT_TEMPLATE
+    concurrent: bool = True
+    normalize: float = 10
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=MAX_CONTEXT_LEN,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )
+
+    def run(
+        self,
+        documents: list[Document],
+        query: str,
+    ) -> list[Document]:
+        """Filter down documents based on their relevance to the query."""
+        filtered_docs = []
+
+        documents = sorted(documents, key=lambda doc: doc.get_content())
+        if self.concurrent:
+            with ThreadPoolExecutor() as executor:
+                futures = []
+                for doc in documents:
+                    chunked_doc_content = self.trim_func(
+                        [
+                            Document(content=doc.get_content())
+                            # skip metadata which cause troubles
+                        ]
+                    )[0].text
+
+                    messages = []
+                    messages.append(
+                        SystemMessage(self.system_prompt_template.populate())
+                    )
+                    messages.append(
+                        HumanMessage(
+                            self.user_prompt_template.populate(
+                                question=query, context=chunked_doc_content
+                            )
+                        )
+                    )
+
+                    def llm_call():
+                        return self.llm(messages).text
+
+                    futures.append(executor.submit(llm_call))
+
+                results = [future.result() for future in futures]
+        else:
+            results = []
+            for doc in documents:
+                messages = []
+                messages.append(SystemMessage(self.system_prompt_template.populate()))
+                messages.append(
+                    SystemMessage(
+                        self.user_prompt_template.populate(
+                            question=query, context=doc.get_content()
+                        )
+                    )
+                )
+                results.append(self.llm(messages).text)
+
+        # use Boolean parser to extract relevancy output from LLM
+        results = [
+            (r_idx, float(re_0_10_rating(result)) / self.normalize)
+            for r_idx, result in enumerate(results)
+        ]
+        results.sort(key=lambda x: x[1], reverse=True)
+
+        for r_idx, score in results:
+            doc = documents[r_idx]
+            doc.metadata["llm_trulens_score"] = score
+            filtered_docs.append(doc)
+
+        print(
+            "LLM rerank scores",
+            [doc.metadata["llm_trulens_score"] for doc in filtered_docs],
+        )
+
+        return filtered_docs
@@ -23,7 +23,7 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
     )

     def _get_li_class(self):
-        from llama_index.text_splitter import TokenTextSplitter
+        from llama_index.core.text_splitter import TokenTextSplitter

         return TokenTextSplitter

@@ -44,6 +44,6 @@ class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
     )

     def _get_li_class(self):
-        from llama_index.node_parser import SentenceWindowNodeParser
+        from llama_index.core.node_parser import SentenceWindowNodeParser

         return SentenceWindowNodeParser
@@ -1,14 +1,18 @@
 from __future__ import annotations

+import threading
 import uuid
 from pathlib import Path
 from typing import Optional, Sequence, cast

+from theflow.settings import settings as flowsettings
+
 from kotaemon.base import BaseComponent, Document, RetrievedDocument
 from kotaemon.embeddings import BaseEmbeddings
 from kotaemon.storages import BaseDocumentStore, BaseVectorStore

 from .base import BaseIndexing, BaseRetrieval
-from .rankings import BaseReranking
+from .rankings import BaseReranking, LLMReranking

 VECTOR_STORE_FNAME = "vectorstore"
 DOC_STORE_FNAME = "docstore"
@@ -23,9 +27,11 @@ class VectorIndexing(BaseIndexing):
         - List of texts
     """

+    cache_dir: Optional[str] = getattr(flowsettings, "KH_CHUNKS_OUTPUT_DIR", None)
     vector_store: BaseVectorStore
     doc_store: Optional[BaseDocumentStore] = None
     embedding: BaseEmbeddings
+    count_: int = 0

     def to_retrieval_pipeline(self, *args, **kwargs):
         """Convert the indexing pipeline to a retrieval pipeline"""
@@ -44,6 +50,52 @@ class VectorIndexing(BaseIndexing):
             qa_pipeline=CitationQAPipeline(**kwargs),
         )

+    def write_chunk_to_file(self, docs: list[Document]):
+        # save the chunks content into markdown format
+        if self.cache_dir:
+            file_name = Path(docs[0].metadata["file_name"])
+            for i in range(len(docs)):
+                markdown_content = ""
+                if "page_label" in docs[i].metadata:
+                    page_label = str(docs[i].metadata["page_label"])
+                    markdown_content += f"Page label: {page_label}"
+                if "file_name" in docs[i].metadata:
+                    filename = docs[i].metadata["file_name"]
+                    markdown_content += f"\nFile name: {filename}"
+                if "section" in docs[i].metadata:
+                    section = docs[i].metadata["section"]
+                    markdown_content += f"\nSection: {section}"
+                if "type" in docs[i].metadata:
+                    if docs[i].metadata["type"] == "image":
+                        image_origin = docs[i].metadata["image_origin"]
+                        image_origin = f'<p><img src="{image_origin}"></p>'
+                        markdown_content += f"\nImage origin: {image_origin}"
+                if docs[i].text:
+                    markdown_content += f"\ntext:\n{docs[i].text}"
+
+                with open(
+                    Path(self.cache_dir) / f"{file_name.stem}_{self.count_+i}.md",
+                    "w",
+                    encoding="utf-8",
+                ) as f:
+                    f.write(markdown_content)
+
+    def add_to_docstore(self, docs: list[Document]):
+        if self.doc_store:
+            print("Adding documents to doc store")
+            self.doc_store.add(docs)
+
+    def add_to_vectorstore(self, docs: list[Document]):
+        # in case we want to skip embedding
+        if self.vector_store:
+            print(f"Getting embeddings for {len(docs)} nodes")
+            embeddings = self.embedding(docs)
+            print("Adding embeddings to vector store")
+            self.vector_store.add(
+                embeddings=embeddings,
+                ids=[t.doc_id for t in docs],
+            )
+
     def run(self, text: str | list[str] | Document | list[Document]):
         input_: list[Document] = []
         if not isinstance(text, list):
@@ -59,16 +111,10 @@ class VectorIndexing(BaseIndexing):
                     f"Invalid input type {type(item)}, should be str or Document"
                 )

-        print(f"Getting embeddings for {len(input_)} nodes")
-        embeddings = self.embedding(input_)
-        print("Adding embeddings to vector store")
-        self.vector_store.add(
-            embeddings=embeddings,
-            ids=[t.doc_id for t in input_],
-        )
-        if self.doc_store:
-            print("Adding documents to doc store")
-            self.doc_store.add(input_)
+        self.add_to_vectorstore(input_)
+        self.add_to_docstore(input_)
+        self.write_chunk_to_file(input_)
+        self.count_ += len(input_)


 class VectorRetrieval(BaseRetrieval):
@@ -78,7 +124,16 @@ class VectorRetrieval(BaseRetrieval):
     doc_store: Optional[BaseDocumentStore] = None
     embedding: BaseEmbeddings
     rerankers: Sequence[BaseReranking] = []
-    top_k: int = 1
+    top_k: int = 5
+    first_round_top_k_mult: int = 10
+    retrieval_mode: str = "hybrid"  # vector, text, hybrid
+
+    def _filter_docs(
+        self, documents: list[RetrievedDocument], top_k: int | None = None
+    ):
+        if top_k:
+            documents = documents[:top_k]
+        return documents

     def run(
         self, text: str | Document, top_k: Optional[int] = None, **kwargs
@@ -95,24 +150,155 @@ class VectorRetrieval(BaseRetrieval):
         if top_k is None:
             top_k = self.top_k

+        do_extend = kwargs.pop("do_extend", False)
+        thumbnail_count = kwargs.pop("thumbnail_count", 3)
+
+        if do_extend:
+            top_k_first_round = top_k * self.first_round_top_k_mult
+        else:
+            top_k_first_round = top_k
+
         if self.doc_store is None:
             raise ValueError(
                 "doc_store is not provided. Please provide a doc_store to "
                 "retrieve the documents"
             )

-        emb: list[float] = self.embedding(text)[0].embedding
-        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
-        docs = self.doc_store.get(ids)
-        result = [
-            RetrievedDocument(**doc.to_dict(), score=score)
-            for doc, score in zip(docs, scores)
-        ]
+        result: list[RetrievedDocument] = []
+        # TODO: should declare scope directly in the run params
+        scope = kwargs.pop("scope", None)
+        emb: list[float]
+
+        if self.retrieval_mode == "vector":
+            emb = self.embedding(text)[0].embedding
+            _, scores, ids = self.vector_store.query(
+                embedding=emb, top_k=top_k_first_round, **kwargs
+            )
+            docs = self.doc_store.get(ids)
+            result = [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(docs, scores)
+            ]
+        elif self.retrieval_mode == "text":
+            query = text.text if isinstance(text, Document) else text
+            docs = self.doc_store.query(query, top_k=top_k_first_round, doc_ids=scope)
+            result = [RetrievedDocument(**doc.to_dict(), score=-1.0) for doc in docs]
+        elif self.retrieval_mode == "hybrid":
+            # similarity search section
+            emb = self.embedding(text)[0].embedding
+            vs_docs: list[RetrievedDocument] = []
+            vs_ids: list[str] = []
+            vs_scores: list[float] = []
+
+            def query_vectorstore():
+                nonlocal vs_docs
+                nonlocal vs_scores
+                nonlocal vs_ids
+
+                assert self.doc_store is not None
+                _, vs_scores, vs_ids = self.vector_store.query(
+                    embedding=emb, top_k=top_k_first_round, **kwargs
+                )
+                if vs_ids:
+                    vs_docs = self.doc_store.get(vs_ids)
+
+            # full-text search section
+            ds_docs: list[RetrievedDocument] = []
+
+            def query_docstore():
+                nonlocal ds_docs
+
+                assert self.doc_store is not None
+                query = text.text if isinstance(text, Document) else text
+                ds_docs = self.doc_store.query(
+                    query, top_k=top_k_first_round, doc_ids=scope
+                )
+
+            vs_query_thread = threading.Thread(target=query_vectorstore)
+            ds_query_thread = threading.Thread(target=query_docstore)
+
+            vs_query_thread.start()
+            ds_query_thread.start()
+
+            vs_query_thread.join()
+            ds_query_thread.join()
+
+            result = [
+                RetrievedDocument(**doc.to_dict(), score=-1.0)
+                for doc in ds_docs
+                if doc not in vs_ids
+            ]
+            result += [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(vs_docs, vs_scores)
+            ]
+            print(f"Got {len(vs_docs)} from vectorstore")
+            print(f"Got {len(ds_docs)} from docstore")

         # use additional reranker to re-order the document list
-        if self.rerankers:
+        if self.rerankers and text:
             for reranker in self.rerankers:
+                # if reranker is LLMReranking, limit the document with top_k items only
+                if isinstance(reranker, LLMReranking):
+                    result = self._filter_docs(result, top_k=top_k)
                 result = reranker(documents=result, query=text)

+        result = self._filter_docs(result, top_k=top_k)
+        print(f"Got raw {len(result)} retrieved documents")
+
+        # add page thumbnails to the result if exists
+        thumbnail_doc_ids: set[str] = set()
+        # we should copy the text from retrieved text chunk
+        # to the thumbnail to get relevant LLM score correctly
+        text_thumbnail_docs: dict[str, RetrievedDocument] = {}
+
+        non_thumbnail_docs = []
+        raw_thumbnail_docs = []
+        for doc in result:
+            if doc.metadata.get("type") == "thumbnail":
+                # change type to image to display on UI
+                doc.metadata["type"] = "image"
+                raw_thumbnail_docs.append(doc)
+                continue
+            if (
+                "thumbnail_doc_id" in doc.metadata
+                and len(thumbnail_doc_ids) < thumbnail_count
+            ):
+                thumbnail_id = doc.metadata["thumbnail_doc_id"]
+                thumbnail_doc_ids.add(thumbnail_id)
+                text_thumbnail_docs[thumbnail_id] = doc
+            else:
+                non_thumbnail_docs.append(doc)
+
+        linked_thumbnail_docs = self.doc_store.get(list(thumbnail_doc_ids))
+        print(
+            "thumbnail docs",
+            len(linked_thumbnail_docs),
+            "non-thumbnail docs",
+            len(non_thumbnail_docs),
+            "raw-thumbnail docs",
+            len(raw_thumbnail_docs),
+        )
+        additional_docs = []
+
+        for thumbnail_doc in linked_thumbnail_docs:
+            text_doc = text_thumbnail_docs[thumbnail_doc.doc_id]
+            doc_dict = thumbnail_doc.to_dict()
+            doc_dict["_id"] = text_doc.doc_id
+            doc_dict["content"] = text_doc.content
+            doc_dict["metadata"]["type"] = "image"
+            for key in text_doc.metadata:
+                if key not in doc_dict["metadata"]:
+                    doc_dict["metadata"][key] = text_doc.metadata[key]
+
+            additional_docs.append(RetrievedDocument(**doc_dict, score=text_doc.score))
+
+        result = additional_docs + non_thumbnail_docs
+
+        if not result:
+            # return output from raw retrieved thumbnails
+            result = self._filter_docs(raw_thumbnail_docs, top_k=thumbnail_count)
+
         return result
@@ -7,6 +7,7 @@ from .chats import (
|
||||
ChatLLM,
|
||||
ChatOpenAI,
|
||||
EndpointChatLLM,
|
||||
LCAnthropicChat,
|
||||
LCAzureChatOpenAI,
|
||||
LCChatOpenAI,
|
||||
LlamaCppChat,
|
||||
@@ -27,6 +28,7 @@ __all__ = [
|
||||
"SystemMessage",
|
||||
"AzureChatOpenAI",
|
||||
"ChatOpenAI",
|
||||
"LCAnthropicChat",
|
||||
"LCAzureChatOpenAI",
|
||||
"LCChatOpenAI",
|
||||
"LlamaCppChat",
|
||||
|
@@ -1,6 +1,11 @@
|
||||
from .base import ChatLLM
|
||||
from .endpoint_based import EndpointChatLLM
|
||||
from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
|
||||
from .langchain_based import (
|
||||
LCAnthropicChat,
|
||||
LCAzureChatOpenAI,
|
||||
LCChatMixin,
|
||||
LCChatOpenAI,
|
||||
)
|
||||
from .llamacpp import LlamaCppChat
|
||||
from .openai import AzureChatOpenAI, ChatOpenAI
|
||||
|
||||
@@ -10,6 +15,7 @@ __all__ = [
|
||||
"ChatLLM",
|
||||
"EndpointChatLLM",
|
||||
"ChatOpenAI",
|
||||
"LCAnthropicChat",
|
||||
"LCChatOpenAI",
|
||||
"LCAzureChatOpenAI",
|
||||
"LCChatMixin",
|
||||
|
@@ -221,3 +221,27 @@ class LCAzureChatOpenAI(LCChatMixin, ChatLLM):  # type: ignore
         from langchain.chat_models import AzureChatOpenAI

         return AzureChatOpenAI
+
+
+class LCAnthropicChat(LCChatMixin, ChatLLM):  # type: ignore
+    def __init__(
+        self,
+        api_key: str | None = None,
+        model_name: str | None = None,
+        temperature: float = 0.7,
+        **params,
+    ):
+        super().__init__(
+            api_key=api_key,
+            model_name=model_name,
+            temperature=temperature,
+            **params,
+        )
+
+    def _get_lc_class(self):
+        try:
+            from langchain_anthropic import ChatAnthropic
+        except ImportError:
+            raise ImportError("Please install langchain-anthropic")
+
+        return ChatAnthropic
@@ -159,6 +159,15 @@ class BaseChatOpenAI(ChatLLM):
             additional_kwargs["tool_calls"] = resp["choices"][0]["message"][
                 "tool_calls"
             ]

+        if resp["choices"][0].get("logprobs") is None:
+            logprobs = []
+        else:
+            all_logprobs = resp["choices"][0]["logprobs"].get("content")
+            logprobs = (
+                [logprob["logprob"] for logprob in all_logprobs] if all_logprobs else []
+            )
+
         output = LLMInterface(
             candidates=[(_["message"]["content"] or "") for _ in resp["choices"]],
             content=resp["choices"][0]["message"]["content"] or "",
@@ -170,6 +179,7 @@ class BaseChatOpenAI(ChatLLM):
                 AIMessage(content=(_["message"]["content"]) or "")
                 for _ in resp["choices"]
             ],
+            logprobs=logprobs,
         )

         return output
@@ -216,11 +226,24 @@ class BaseChatOpenAI(ChatLLM):
             client, messages=input_messages, stream=True, **kwargs
         )

-        for chunk in resp:
-            if not chunk.choices:
+        for c in resp:
+            chunk = c.dict()
+            if not chunk["choices"]:
                 continue
-            if chunk.choices[0].delta.content is not None:
-                yield LLMInterface(content=chunk.choices[0].delta.content)
+            if chunk["choices"][0]["delta"]["content"] is not None:
+                if chunk["choices"][0].get("logprobs") is None:
+                    logprobs = []
+                else:
+                    logprobs = [
+                        logprob["logprob"]
+                        for logprob in chunk["choices"][0]["logprobs"].get(
+                            "content", []
+                        )
+                    ]
+
+                yield LLMInterface(
+                    content=chunk["choices"][0]["delta"]["content"], logprobs=logprobs
+                )

     async def astream(
         self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
@@ -3,10 +3,12 @@ from .azureai_document_intelligence_loader import AzureAIDocumentIntelligenceLoader
 from .base import AutoReader, BaseReader
 from .composite_loader import DirectoryReader
 from .docx_loader import DocxReader
-from .excel_loader import PandasExcelReader
+from .excel_loader import ExcelReader, PandasExcelReader
 from .html_loader import HtmlReader, MhtmlReader
 from .mathpix_loader import MathpixPDFReader
 from .ocr_loader import ImageReader, OCRReader
+from .pdf_loader import PDFThumbnailReader
+from .txt_loader import TxtReader
 from .unstructured_loader import UnstructuredReader

 __all__ = [
@@ -14,6 +16,7 @@ __all__ = [
     "AzureAIDocumentIntelligenceLoader",
     "BaseReader",
     "PandasExcelReader",
+    "ExcelReader",
     "MathpixPDFReader",
     "ImageReader",
     "OCRReader",
@@ -23,4 +26,6 @@ __all__ = [
     "HtmlReader",
     "MhtmlReader",
     "AdobeReader",
+    "TxtReader",
+    "PDFThumbnailReader",
 ]
@@ -6,7 +6,7 @@ from pathlib import Path
 from typing import Any, Dict, List, Optional

 from decouple import config
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -154,7 +154,7 @@ class AdobeReader(BaseReader):
         for page_number, table_content, table_caption in tables:
             documents.append(
                 Document(
-                    text=table_caption,
+                    text=table_content,
                     metadata={
                         "table_origin": table_content,
                         "type": "table",
@@ -1,10 +1,56 @@
|
||||
import base64
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from kotaemon.base import Document, Param
|
||||
|
||||
from .base import BaseReader
|
||||
from .utils.adobe import generate_single_figure_caption
|
||||
|
||||
|
||||
def crop_image(file_path: Path, bbox: list[float], page_number: int = 0) -> Image.Image:
|
||||
"""Crop the image based on the bounding box
|
||||
|
||||
Args:
|
||||
file_path (Path): path to the image file
|
||||
bbox (list[float]): bounding box of the image (in percentage [x0, y0, x1, y1])
|
||||
page_number (int, optional): page number of the image. Defaults to 0.
|
||||
|
||||
Returns:
|
||||
Image.Image: cropped image
|
||||
"""
|
||||
left, upper, right, lower = bbox
|
||||
|
||||
img: Image.Image
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix == ".pdf":
|
||||
try:
|
||||
import fitz
|
||||
except ImportError:
|
||||
raise ImportError("Please install PyMuPDF: 'pip install PyMuPDF'")
|
||||
|
||||
doc = fitz.open(file_path)
|
||||
page = doc.load_page(page_number)
|
||||
pm = page.get_pixmap(dpi=150)
|
||||
img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
|
||||
elif suffix in [".tif", ".tiff"]:
|
||||
img = Image.open(file_path)
|
||||
img.seek(page_number)
|
||||
else:
|
||||
img = Image.open(file_path)
|
||||
|
||||
return img.crop(
|
||||
(
|
||||
int(left * img.width),
|
||||
int(upper * img.height),
|
||||
int(right * img.width),
|
||||
int(lower * img.height),
|
||||
)
|
||||
)
|
@@ -14,7 +60,7 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
     heif, docx, xlsx, pptx and html.
     """

-    _dependencies = ["azure-ai-documentintelligence"]
+    _dependencies = ["azure-ai-documentintelligence", "PyMuPDF", "Pillow"]

     endpoint: str = Param(
         os.environ.get("AZUREAI_DOCUMENT_INTELLIGENT_ENDPOINT", None),
@@ -34,6 +80,29 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
             "#model-analysis-features)"
         ),
     )
+    output_content_format: str = Param(
+        "markdown",
+        help="Output content format. Can be 'markdown' or 'text'. Default is 'markdown'.",
+    )
+    vlm_endpoint: str = Param(
+        help=(
+            "Default VLM endpoint for figure captioning. If not provided, will not "
+            "caption the figures"
+        ),
+    )
+    figure_friendly_filetypes: list[str] = Param(
+        [".pdf", ".jpeg", ".jpg", ".png", ".bmp", ".tiff", ".heif", ".tif"],
+        help=(
+            "File types that we can reliably open and extract figures from. "
+            "For files like .docx or .html, the visual layout may differ "
+            "depending on the viewing tool, hence we cannot use the Azure DI "
+            "locations to extract figures."
+        ),
+    )
+    cache_dir: str = Param(
+        None,
+        help="Directory to cache the downloaded files. Default is None",
+    )

     @Param.auto(depends_on=["endpoint", "credential"])
     def client_(self):
@@ -55,14 +124,114 @@ class AzureAIDocumentIntelligenceLoader(BaseReader):
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> list[Document]:
         """Extract the input file, allowing multi-modal extraction"""
         metadata = extra_info or {}
+        file_name = Path(file_path)
         with open(file_path, "rb") as fi:
             poller = self.client_.begin_analyze_document(
                 self.model,
                 analyze_request=fi,
                 content_type="application/octet-stream",
-                output_content_format="markdown",
+                output_content_format=self.output_content_format,
             )
             result = poller.result()

-        return [Document(content=result.content, metadata=metadata)]
+        # the total text content of the document, in `output_content_format` format
+        text_content = result.content
+        removed_spans: list[dict] = []
+
+        # extract the figures
+        figures = []
+        for figure_desc in result.get("figures", []):
+            if not self.vlm_endpoint:
+                continue
+            if file_path.suffix.lower() not in self.figure_friendly_filetypes:
+                continue
+
+            # read & crop the image
+            page_number = figure_desc["boundingRegions"][0]["pageNumber"]
+            page_width = result.pages[page_number - 1]["width"]
+            page_height = result.pages[page_number - 1]["height"]
+            polygon = figure_desc["boundingRegions"][0]["polygon"]
+            xs = [polygon[i] for i in range(0, len(polygon), 2)]
+            ys = [polygon[i] for i in range(1, len(polygon), 2)]
+            bbox = [
+                min(xs) / page_width,
+                min(ys) / page_height,
+                max(xs) / page_width,
+                max(ys) / page_height,
+            ]
+            img = crop_image(file_path, bbox, page_number - 1)
+
+            # convert the image into base64
+            img_bytes = BytesIO()
+            img.save(img_bytes, format="PNG")
+            img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
+            img_base64 = f"data:image/png;base64,{img_base64}"
+
+            # caption the image
+            caption = generate_single_figure_caption(
+                figure=img_base64, vlm_endpoint=self.vlm_endpoint
+            )
+
+            # store the image into document
+            figure_metadata = {
+                "image_origin": img_base64,
+                "type": "image",
+                "page_label": page_number,
+            }
+            figure_metadata.update(metadata)
+
+            figures.append(
+                Document(
+                    text=caption,
+                    metadata=figure_metadata,
+                )
+            )
+            removed_spans += figure_desc["spans"]
+
+        # extract the tables
+        tables = []
+        for table_desc in result.get("tables", []):
+            if not table_desc["spans"]:
+                continue
+
+            # convert the tables into markdown format
+            boundingRegions = table_desc["boundingRegions"]
+            if boundingRegions:
+                page_number = boundingRegions[0]["pageNumber"]
+            else:
+                page_number = 1
+
+            # store the tables into document
+            offset = table_desc["spans"][0]["offset"]
+            length = table_desc["spans"][0]["length"]
+            table_metadata = {
+                "type": "table",
+                "page_label": page_number,
+                "table_origin": text_content[offset : offset + length],
+            }
+            table_metadata.update(metadata)
+
+            tables.append(
+                Document(
+                    text=text_content[offset : offset + length],
+                    metadata=table_metadata,
+                )
+            )
+            removed_spans += table_desc["spans"]
+
+        # save the text content into markdown format
+        if self.cache_dir is not None:
+            with open(
+                Path(self.cache_dir) / f"{file_name.stem}.md", "w", encoding="utf-8"
+            ) as f:
+                f.write(text_content)
+
+        # cut the figure/table spans out of the main text, back to front so
+        # earlier offsets stay valid
+        removed_spans = sorted(removed_spans, key=lambda x: x["offset"], reverse=True)
+        for span in removed_spans:
+            text_content = (
+                text_content[: span["offset"]]
+                + text_content[span["offset"] + span["length"] :]
+            )
+
+        return [Document(content=text_content, metadata=metadata)] + figures + tables
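A worked illustration of the bounding-box arithmetic in the figure-extraction loop above: Azure DI reports each figure region as a flat polygon [x0, y0, x1, y1, ...] in page units, which the loader normalizes into a relative [left, top, right, bottom] box before cropping (illustrative numbers, not from a real document):

# Flat polygon from Azure DI: alternating x, y coordinates in page units.
polygon = [1.0, 2.0, 4.0, 2.0, 4.0, 5.0, 1.0, 5.0]
page_width, page_height = 8.5, 11.0

xs = polygon[0::2]  # every even index is an x coordinate
ys = polygon[1::2]  # every odd index is a y coordinate
bbox = [
    min(xs) / page_width,   # left
    min(ys) / page_height,  # top
    max(xs) / page_width,   # right
    max(ys) / page_height,  # bottom
]
print(bbox)  # [0.1176..., 0.1818..., 0.4705..., 0.4545...]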
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Any, List, Type, Union

 from kotaemon.base import BaseComponent, Document

 if TYPE_CHECKING:
-    from llama_index.readers.base import BaseReader as LIBaseReader
+    from llama_index.core.readers.base import BaseReader as LIBaseReader


 class BaseReader(BaseComponent):
@@ -20,7 +20,7 @@ class AutoReader(BaseReader):
         """Init reader using string identifier or class name from llama-hub"""

         if isinstance(reader_type, str):
-            from llama_index import download_loader
+            from llama_index.core import download_loader

             self._reader = download_loader(reader_type)()
         else:
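Most of the remaining hunks repeat this same mechanical migration: llama-index 0.10 moved its core symbols from the top-level package into the llama_index.core namespace. A representative before/after, assuming llama-index >= 0.10 is installed:

# llama-index 0.9.x:
#     from llama_index.readers.base import BaseReader
# llama-index 0.10.x — the same class now lives under llama_index.core:
from llama_index.core.readers.base import BaseReader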
@@ -1,6 +1,6 @@
 from typing import Callable, List, Optional, Type

-from llama_index.readers.base import BaseReader as LIBaseReader
+from llama_index.core.readers.base import BaseReader as LIBaseReader

 from .base import BaseReader, LIReaderMixin

@@ -48,6 +48,6 @@ class DirectoryReader(LIReaderMixin, BaseReader):
     file_metadata: Optional[Callable[[str], dict]] = None

     def _get_wrapped_class(self) -> Type["LIBaseReader"]:
-        from llama_index import SimpleDirectoryReader
+        from llama_index.core import SimpleDirectoryReader

         return SimpleDirectoryReader
@@ -3,7 +3,7 @@ from pathlib import Path
 from typing import List, Optional

 import pandas as pd
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -27,6 +27,21 @@ class DocxReader(BaseReader):
             "Please install it using `pip install python-docx`"
         )

+    def _load_single_table(self, table) -> List[List[str]]:
+        """Extract content from tables. Return a list of columns: list[str]
+        Some merged cells will share duplicated content.
+        """
+        n_row = len(table.rows)
+        n_col = len(table.columns)
+
+        arrays = [["" for _ in range(n_row)] for _ in range(n_col)]
+
+        for i, row in enumerate(table.rows):
+            for j, cell in enumerate(row.cells):
+                arrays[j][i] = cell.text
+
+        return arrays
+
     def load_data(
         self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
     ) -> List[Document]:
@@ -50,13 +65,9 @@ class DocxReader(BaseReader):

         tables = []
         for t in doc.tables:
-            arrays = [
-                [
-                    unicodedata.normalize("NFKC", t.cell(i, j).text)
-                    for i in range(len(t.rows))
-                ]
-                for j in range(len(t.columns))
-            ]
+            # return list of columns: list of string
+            arrays = self._load_single_table(t)

             tables.append(pd.DataFrame({a[0]: a[1:] for a in arrays}))

         extra_info = extra_info or {}
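A small sketch of how the column-major layout produced by _load_single_table maps onto a DataFrame (toy values; the first element of each column is treated as the header):

import pandas as pd

# One inner list per column, mimicking _load_single_table output.
arrays = [
    ["name", "Alice", "Bob"],  # column 1: header + cells
    ["age", "30", "25"],       # column 2: header + cells
]
df = pd.DataFrame({a[0]: a[1:] for a in arrays})
print(df.to_string(index=False))
#  name age
# Alice  30
#   Bob  25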
@@ -6,7 +6,7 @@ Pandas parser for .xlsx files.
 from pathlib import Path
 from typing import Any, List, Optional, Union

-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -82,6 +82,9 @@ class PandasExcelReader(BaseReader):
             sheet = []
             if include_sheetname:
                 sheet.append([key])
+            dfs[key] = dfs[key].dropna(axis=0, how="all")
+            dfs[key] = dfs[key].dropna(axis=1, how="all")
+            dfs[key].fillna("", inplace=True)
             sheet.extend(dfs[key].values.astype(str).tolist())
             df_sheets.append(sheet)

@@ -99,3 +102,91 @@ class PandasExcelReader(BaseReader):
         ]

         return output
+
+
+class ExcelReader(BaseReader):
+    r"""Spreadsheet reader respecting multiple worksheets.
+
+    Parses the file using the Pandas `read_excel` function.
+    If special parameters are required, use the `pandas_config` dict.
+
+    Args:
+        pandas_config (dict): Options for the `pandas.read_excel` function call.
+            Refer to https://pandas.pydata.org/docs/reference/api/pandas.read_excel.html
+            for more information. Set to empty dict by default,
+            this means defaults will be used.
+    """
+
+    def __init__(
+        self,
+        *args: Any,
+        pandas_config: Optional[dict] = None,
+        row_joiner: str = "\n",
+        col_joiner: str = " ",
+        **kwargs: Any,
+    ) -> None:
+        """Init params."""
+        super().__init__(*args, **kwargs)
+        self._pandas_config = pandas_config or {}
+        self._row_joiner = row_joiner if row_joiner else "\n"
+        self._col_joiner = col_joiner if col_joiner else " "
+
+    def load_data(
+        self,
+        file: Path,
+        include_sheetname: bool = True,
+        sheet_name: Optional[Union[str, int, list]] = None,
+        extra_info: Optional[dict] = None,
+        **kwargs,
+    ) -> List[Document]:
+        """Parse the Excel file into one Document per worksheet.
+
+        Args:
+            file (Path): The path to the Excel file to read.
+            include_sheetname (bool): Whether to include the sheet name in the output.
+            sheet_name (Union[str, int, None]): The specific sheet to read from,
+                default is None which reads all sheets.
+
+        Returns:
+            List[Document]: A list of `Document` objects, one per worksheet.
+        """
+        try:
+            import pandas as pd
+        except ImportError:
+            raise ImportError(
+                "install pandas using `pip3 install pandas` to use this loader"
+            )
+
+        if sheet_name is not None:
+            sheet_name = (
+                [sheet_name] if not isinstance(sheet_name, list) else sheet_name
+            )
+
+        # clean up input
+        file = Path(file)
+        extra_info = extra_info or {}
+
+        dfs = pd.read_excel(file, sheet_name=sheet_name, **self._pandas_config)
+        sheet_names = dfs.keys()
+        output = []
+
+        for idx, key in enumerate(sheet_names):
+            dfs[key] = dfs[key].dropna(axis=0, how="all")
+            dfs[key] = dfs[key].dropna(axis=1, how="all")
+            dfs[key] = dfs[key].astype("object")
+            dfs[key].fillna("", inplace=True)
+
+            rows = dfs[key].values.astype(str).tolist()
+            content = self._row_joiner.join(
+                self._col_joiner.join(row).strip() for row in rows
+            ).strip()
+            if include_sheetname:
+                content = f"(Sheet {key} of file {file.name})\n{content}"
+            metadata = {"page_label": idx + 1, "sheet_name": key, **extra_info}
+            output.append(Document(text=content, metadata=metadata))
+
+        return output
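A hypothetical usage sketch of the new ExcelReader (the import path is assumed; adjust to wherever the class is exported):

from pathlib import Path

from kotaemon.loaders import ExcelReader  # assumed export path

reader = ExcelReader(row_joiner="\n", col_joiner=" | ")
docs = reader.load_data(Path("report.xlsx"), include_sheetname=True)
for doc in docs:
    print(doc.metadata["sheet_name"], len(doc.text))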
@@ -2,7 +2,8 @@ import email
 from pathlib import Path
 from typing import Optional

-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader
+from theflow.settings import settings as flowsettings

 from kotaemon.base import Document

@@ -78,6 +79,9 @@ class MhtmlReader(BaseReader):

     def __init__(
         self,
+        cache_dir: Optional[str] = getattr(
+            flowsettings, "KH_MARKDOWN_OUTPUT_DIR", None
+        ),
         open_encoding: Optional[str] = None,
         bs_kwargs: Optional[dict] = None,
         get_text_separator: str = "",
@@ -86,6 +90,7 @@ class MhtmlReader(BaseReader):
         to pass to the BeautifulSoup object.

         Args:
+            cache_dir: Path where the markdown version of the file is cached.
             file_path: Path to file to load.
             open_encoding: The encoding to use when opening the file.
             bs_kwargs: Any kwargs to pass to the BeautifulSoup object.
@@ -100,6 +105,7 @@ class MhtmlReader(BaseReader):
                 "`pip install beautifulsoup4`"
             )

+        self.cache_dir = cache_dir
         self.open_encoding = open_encoding
         if bs_kwargs is None:
             bs_kwargs = {"features": "lxml"}
@@ -116,6 +122,7 @@ class MhtmlReader(BaseReader):
         extra_info = extra_info or {}
         metadata: dict = extra_info
         page = []
+        file_name = Path(file_path)
         with open(file_path, "r", encoding=self.open_encoding) as f:
             message = email.message_from_string(f.read())
             parts = message.get_payload()
@@ -144,5 +151,11 @@ class MhtmlReader(BaseReader):
         text = "\n\n".join(lines)
         if text:
             page.append(text)
+        # save the page into markdown format
+        if self.cache_dir is not None:
+            with open(Path(self.cache_dir) / f"{file_name.stem}.md", "w") as f:
+                f.write(page[0])

         return [Document(text="\n\n".join(page), metadata=metadata)]
@@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional

 import requests
 from langchain.utils import get_from_dict_or_env
-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -5,8 +5,8 @@ from typing import List, Optional
 from uuid import uuid4

 import requests
-from llama_index.readers.base import BaseReader
-from tenacity import after_log, retry, stop_after_attempt, wait_fixed, wait_random
+from llama_index.core.readers.base import BaseReader
+from tenacity import after_log, retry, stop_after_attempt, wait_exponential

 from kotaemon.base import Document

@@ -19,13 +19,16 @@ DEFAULT_OCR_ENDPOINT = "http://127.0.0.1:8000/v2/ai/infer/"


 @retry(
-    stop=stop_after_attempt(3),
-    wait=wait_fixed(5) + wait_random(0, 2),
-    after=after_log(logger, logging.DEBUG),
+    stop=stop_after_attempt(6),
+    wait=wait_exponential(multiplier=20, exp_base=2, min=1, max=1000),
+    after=after_log(logger, logging.WARNING),
 )
-def tenacious_api_post(url, **kwargs):
-    resp = requests.post(url=url, **kwargs)
-    resp.raise_for_status()
+def tenacious_api_post(url, file_path, table_only, **kwargs):
+    with file_path.open("rb") as content:
+        files = {"input": content}
+        data = {"job_id": uuid4(), "table_only": table_only}
+        resp = requests.post(url=url, files=files, data=data, **kwargs)
+        resp.raise_for_status()
     return resp


@@ -71,18 +74,16 @@ class OCRReader(BaseReader):
         """
         file_path = Path(file_path).resolve()

-        with file_path.open("rb") as content:
-            files = {"input": content}
-            data = {"job_id": uuid4(), "table_only": not self.use_ocr}
-
-            # call the API from FullOCR endpoint
-            if "response_content" in kwargs:
-                # overriding response content if specified
-                ocr_results = kwargs["response_content"]
-            else:
-                # call original API
-                resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
-                ocr_results = resp.json()["result"]
+        # call the API from FullOCR endpoint
+        if "response_content" in kwargs:
+            # overriding response content if specified
+            ocr_results = kwargs["response_content"]
+        else:
+            # call original API
+            resp = tenacious_api_post(
+                url=self.ocr_endpoint, file_path=file_path, table_only=not self.use_ocr
+            )
+            ocr_results = resp.json()["result"]

         debug_path = kwargs.pop("debug_path", None)
         artifact_path = kwargs.pop("artifact_path", None)
@@ -168,18 +169,16 @@ class ImageReader(BaseReader):
         """
         file_path = Path(file_path).resolve()

-        with file_path.open("rb") as content:
-            files = {"input": content}
-            data = {"job_id": uuid4(), "table_only": False}
-
-            # call the API from FullOCR endpoint
-            if "response_content" in kwargs:
-                # overriding response content if specified
-                ocr_results = kwargs["response_content"]
-            else:
-                # call original API
-                resp = tenacious_api_post(url=self.ocr_endpoint, files=files, data=data)
-                ocr_results = resp.json()["result"]
+        # call the API from FullOCR endpoint
+        if "response_content" in kwargs:
+            # overriding response content if specified
+            ocr_results = kwargs["response_content"]
+        else:
+            # call original API
+            resp = tenacious_api_post(
+                url=self.ocr_endpoint, file_path=file_path, table_only=False
+            )
+            ocr_results = resp.json()["result"]

         extra_info = extra_info or {}
         result = []
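For context on the new retry policy: per tenacity's documented behavior, wait_exponential sleeps roughly multiplier * exp_base**attempt seconds after each failed attempt, clamped to [min, max]. A quick sketch of the schedule this decorator implies (an approximation computed from that formula, not output from tenacity itself):

# waits between the 6 attempts allowed by stop_after_attempt(6)
for attempt in range(1, 6):
    wait = max(1, min(20 * 2**attempt, 1000))
    print(f"after failed attempt {attempt}: wait ~{wait}s")
# ~40s, 80s, 160s, 320s, 640s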
libs/kotaemon/kotaemon/loaders/pdf_loader.py (new file, 114 lines)
@@ -0,0 +1,114 @@
import base64
from io import BytesIO
from pathlib import Path
from typing import Dict, List, Optional

from fsspec import AbstractFileSystem
from llama_index.readers.file import PDFReader
from PIL import Image

from kotaemon.base import Document


def get_page_thumbnails(
    file_path: Path, pages: list[int], dpi: int = 80
) -> List[str]:
    """Get image thumbnails of the pages in the PDF file.

    Args:
        file_path (Path): path to the PDF file
        pages (list[int]): list of page numbers to extract

    Returns:
        List[str]: list of base64-encoded page thumbnails
    """
    img: Image.Image
    suffix = file_path.suffix.lower()
    assert suffix == ".pdf", "This function only supports PDF files."
    try:
        import fitz
    except ImportError:
        raise ImportError("Please install PyMuPDF: 'pip install PyMuPDF'")

    doc = fitz.open(file_path)

    output_imgs = []
    for page_number in pages:
        page = doc.load_page(page_number)
        pm = page.get_pixmap(dpi=dpi)
        img = Image.frombytes("RGB", [pm.width, pm.height], pm.samples)
        output_imgs.append(convert_image_to_base64(img))

    return output_imgs


def convert_image_to_base64(img: Image.Image) -> str:
    # serialize the PIL image as a base64 PNG data URI
    img_bytes = BytesIO()
    img.save(img_bytes, format="PNG")
    img_base64 = base64.b64encode(img_bytes.getvalue()).decode("utf-8")
    img_base64 = f"data:image/png;base64,{img_base64}"

    return img_base64


class PDFThumbnailReader(PDFReader):
    """PDF parser with thumbnail for each page."""

    def __init__(self) -> None:
        """Initialize PDFReader."""
        super().__init__(return_full_document=False)

    def load_data(
        self,
        file: Path,
        extra_info: Optional[Dict] = None,
        fs: Optional[AbstractFileSystem] = None,
    ) -> List[Document]:
        """Parse file."""
        documents = super().load_data(file, extra_info, fs)

        page_numbers_str = []
        filtered_docs = []
        is_int_page_number: dict[str, bool] = {}

        for doc in documents:
            if "page_label" in doc.metadata:
                page_num_str = doc.metadata["page_label"]
                page_numbers_str.append(page_num_str)
                try:
                    _ = int(page_num_str)
                    is_int_page_number[page_num_str] = True
                    filtered_docs.append(doc)
                except ValueError:
                    is_int_page_number[page_num_str] = False
                    continue

        documents = filtered_docs
        page_numbers = list(range(len(page_numbers_str)))

        page_thumbnails = get_page_thumbnails(file, page_numbers)

        documents.extend(
            [
                Document(
                    text="Page thumbnail",
                    metadata={
                        "image_origin": page_thumbnail,
                        "type": "thumbnail",
                        "page_label": page_number,
                        **(extra_info if extra_info is not None else {}),
                    },
                )
                for (page_thumbnail, page_number) in zip(
                    page_thumbnails, page_numbers_str
                )
                if is_int_page_number[page_number]
            ]
        )

        return documents
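A minimal usage sketch of the new reader (module path as in the diff; a real PDF is needed for it to run):

from pathlib import Path

from kotaemon.loaders.pdf_loader import PDFThumbnailReader

reader = PDFThumbnailReader()
docs = reader.load_data(Path("paper.pdf"))
# Text documents plus one base64 "thumbnail" Document per integer-labeled page.
thumbnails = [d for d in docs if d.metadata.get("type") == "thumbnail"]
print(len(thumbnails), "page thumbnails")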
libs/kotaemon/kotaemon/loaders/txt_loader.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from pathlib import Path
from typing import Optional

from kotaemon.base import Document

from .base import BaseReader


class TxtReader(BaseReader):
    def run(
        self, file_path: str | Path, extra_info: Optional[dict] = None, **kwargs
    ) -> list[Document]:
        return self.load_data(Path(file_path), extra_info=extra_info, **kwargs)

    def load_data(
        self, file_path: Path, extra_info: Optional[dict] = None, **kwargs
    ) -> list[Document]:
        with open(file_path, "r") as f:
            text = f.read()

        metadata = extra_info or {}
        return [Document(text=text, metadata=metadata)]
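Usage sketch: run() accepts either str or Path and delegates to load_data, so the reader drops straight into a kotaemon pipeline (file name illustrative):

from kotaemon.loaders.txt_loader import TxtReader

docs = TxtReader().run("notes.txt", extra_info={"source": "notes"})
print(docs[0].metadata, docs[0].text[:80])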
@@ -12,7 +12,7 @@ pip install xlrd
 from pathlib import Path
 from typing import Any, Dict, List, Optional

-from llama_index.readers.base import BaseReader
+from llama_index.core.readers.base import BaseReader

 from kotaemon.base import Document

@@ -1,12 +1,19 @@
 import json
 import logging
 from typing import Any, List

 import requests
 from decouple import config

+logger = logging.getLogger(__name__)
+

 def generate_gpt4v(
-    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+    endpoint: str,
+    images: str | List[str],
+    prompt: str,
+    max_tokens: int = 512,
+    max_images: int = 10,
 ) -> str:
     # OpenAI API Key
     api_key = config("AZURE_OPENAI_API_KEY", default="")
@@ -27,24 +34,36 @@ def generate_gpt4v(
                         "type": "image_url",
                         "image_url": {"url": image},
                     }
-                    for image in images
+                    for image in images[:max_images]
                 ],
             }
         ],
         "max_tokens": max_tokens,
+        "temperature": 0,
     }

+    if len(images) > max_images:
+        print(f"Truncated to {max_images} images (original {len(images)} images)")
+
-    response = requests.post(endpoint, headers=headers, json=payload)
-
-    try:
-        output = response.json()
-        output = output["choices"][0]["message"]["content"]
-    except Exception:
-        output = ""
+    try:
+        response = requests.post(endpoint, headers=headers, json=payload)
+        response.raise_for_status()
+    except Exception as e:
+        logger.exception(f"Error generating gpt4v: {response.text}; error {e}")
+        return ""
+
+    output = response.json()
+    output = output["choices"][0]["message"]["content"]
     return output


 def stream_gpt4v(
-    endpoint: str, images: str | List[str], prompt: str, max_tokens: int = 512
+    endpoint: str,
+    images: str | List[str],
+    prompt: str,
+    max_tokens: int = 512,
+    max_images: int = 10,
 ) -> Any:
     # OpenAI API Key
     api_key = config("AZURE_OPENAI_API_KEY", default="")
@@ -65,17 +84,22 @@ def stream_gpt4v(
                         "type": "image_url",
                         "image_url": {"url": image},
                     }
-                    for image in images
+                    for image in images[:max_images]
                 ],
             }
         ],
         "max_tokens": max_tokens,
         "stream": True,
+        "logprobs": True,
+        "temperature": 0,
     }
+    if len(images) > max_images:
+        print(f"Truncated to {max_images} images (original {len(images)} images)")
     try:
         response = requests.post(endpoint, headers=headers, json=payload, stream=True)
         assert response.status_code == 200, str(response.content)
         output = ""
+        logprobs = []
         for line in response.iter_lines():
             if line:
                 if line.startswith(b"\xef\xbb\xbf"):
@@ -89,8 +113,23 @@ def stream_gpt4v(
                 except Exception:
                     break
                 if len(line["choices"]):
+                    if line["choices"][0].get("logprobs") is None:
+                        _logprobs = []
+                    else:
+                        _logprobs = [
+                            logprob["logprob"]
+                            for logprob in line["choices"][0]["logprobs"].get(
+                                "content", []
+                            )
+                        ]
+
                     output += line["choices"][0]["delta"].get("content", "")
-                    yield line["choices"][0]["delta"].get("content", "")
-    except Exception:
+                    logprobs += _logprobs
+                    yield line["choices"][0]["delta"].get("content", ""), _logprobs
+
+    except Exception as e:
+        logger.error(f"Error streaming gpt4v {e}")
+        logprobs = []
         output = ""
-    return output
+
+    return output, logprobs
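One way the streamed logprobs can back a confidence score is the geometric-mean token probability; this is an illustrative formula, not necessarily the one used downstream:

import math

token_logprobs = [-0.05, -0.20, -0.01]  # example values emitted by stream_gpt4v
confidence = math.exp(sum(token_logprobs) / len(token_logprobs))
print(f"confidence ~ {confidence:.3f}")  # ~0.917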
@@ -2,12 +2,14 @@ from .docstores import (
     BaseDocumentStore,
     ElasticsearchDocumentStore,
     InMemoryDocumentStore,
+    LanceDBDocumentStore,
     SimpleFileDocumentStore,
 )
 from .vectorstores import (
     BaseVectorStore,
     ChromaVectorStore,
     InMemoryVectorStore,
+    LanceDBVectorStore,
     SimpleFileVectorStore,
 )

@@ -17,9 +19,11 @@ __all__ = [
     "InMemoryDocumentStore",
     "ElasticsearchDocumentStore",
     "SimpleFileDocumentStore",
+    "LanceDBDocumentStore",
     # Vector stores
     "BaseVectorStore",
     "ChromaVectorStore",
     "InMemoryVectorStore",
     "SimpleFileVectorStore",
+    "LanceDBVectorStore",
 ]
@@ -1,6 +1,7 @@
 from .base import BaseDocumentStore
 from .elasticsearch import ElasticsearchDocumentStore
 from .in_memory import InMemoryDocumentStore
+from .lancedb import LanceDBDocumentStore
 from .simple_file import SimpleFileDocumentStore

 __all__ = [
@@ -8,4 +9,5 @@ __all__ = [
     "InMemoryDocumentStore",
     "ElasticsearchDocumentStore",
     "SimpleFileDocumentStore",
+    "LanceDBDocumentStore",
 ]
@@ -41,6 +41,13 @@ class BaseDocumentStore(ABC):
         """Count number of documents"""
         ...

+    @abstractmethod
+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
+        """Search document store using search query"""
+        ...
+
     @abstractmethod
     def delete(self, ids: Union[List[str], str]):
         """Delete document by id"""
@@ -92,7 +92,10 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
                 "_id": doc_id,
             }
             requests.append(request)
-        self.es_bulk(self.client, requests)
+
+        success, failed = self.es_bulk(self.client, requests)
+        print("Added/Updated documents to index", success)
+        print("Failed documents to index", failed)

         if refresh_indices:
             self.client.indices.refresh(index=self.index_name)
@@ -131,16 +134,17 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         Returns:
             List[Document]: List of result documents
         """
-        query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
-        if doc_ids:
-            query_dict["query"]["match"]["_id"] = {"values": doc_ids}
+        query_dict: dict = {"match": {"content": query}}
+        if doc_ids is not None:
+            query_dict = {"bool": {"must": [query_dict, {"terms": {"_id": doc_ids}}]}}
+        query_dict = {"query": query_dict, "size": top_k}
         return self.query_raw(query_dict)

     def get(self, ids: Union[List[str], str]) -> List[Document]:
         """Get document by id"""
         if not isinstance(ids, list):
             ids = [ids]
-        query_dict = {"query": {"terms": {"_id": ids}}}
+        query_dict = {"query": {"terms": {"_id": ids}}, "size": 10000}
         return self.query_raw(query_dict)

     def count(self) -> int:
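For reference, the body the reworked query() sends when doc_ids is supplied — a bool/must that combines the full-text match with a terms filter on _id (illustrative values):

query_dict = {
    "query": {
        "bool": {
            "must": [
                {"match": {"content": "hybrid retrieval"}},
                {"terms": {"_id": ["doc-1", "doc-2"]}},
            ]
        }
    },
    "size": 10,
}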
@@ -81,6 +81,12 @@ class InMemoryDocumentStore(BaseDocumentStore):
         # Also, for portability, use SQLAlchemy for document store.
         self._store = {key: Document.from_dict(value) for key, value in store.items()}

+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
+        """Perform full-text search on document store"""
+        return []
+
     def __persist_flow__(self):
         return {}

libs/kotaemon/kotaemon/storages/docstores/lancedb.py (new file, 153 lines)
@@ -0,0 +1,153 @@
import json
from typing import List, Optional, Union

from kotaemon.base import Document

from .base import BaseDocumentStore

MAX_DOCS_TO_GET = 10**4


class LanceDBDocumentStore(BaseDocumentStore):
    """LanceDB document store which supports full-text search queries"""

    def __init__(self, path: str = "lancedb", collection_name: str = "docstore"):
        try:
            import lancedb
        except ImportError:
            raise ImportError(
                "Please install lancedb: 'pip install lancedb tantivy'"
            )

        self.db_uri = path
        self.collection_name = collection_name
        self.db_connection = lancedb.connect(self.db_uri)  # type: ignore

    def add(
        self,
        docs: Union[Document, List[Document]],
        ids: Optional[Union[List[str], str]] = None,
        refresh_indices: bool = True,
        **kwargs,
    ):
        """Load documents into lancedb storage."""
        doc_ids = ids if ids else [doc.doc_id for doc in docs]
        data: list[dict[str, str]] | None = [
            {
                "id": doc_id,
                "text": doc.text,
                "attributes": json.dumps(doc.metadata),
            }
            for doc_id, doc in zip(doc_ids, docs)
        ]

        if self.collection_name not in self.db_connection.table_names():
            if data:
                document_collection = self.db_connection.create_table(
                    self.collection_name, data=data, mode="overwrite"
                )
        else:
            # add data to existing table
            document_collection = self.db_connection.open_table(self.collection_name)
            if data:
                document_collection.add(data)

        if refresh_indices:
            document_collection.create_fts_index(
                "text",
                tokenizer_name="en_stem",
                replace=True,
            )

    def query(
        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
    ) -> List[Document]:
        if doc_ids:
            id_filter = ", ".join([f"'{_id}'" for _id in doc_ids])
            query_filter = f"id in ({id_filter})"
        else:
            query_filter = None
        try:
            document_collection = self.db_connection.open_table(self.collection_name)
            if query_filter:
                docs = (
                    document_collection.search(query, query_type="fts")
                    .where(query_filter, prefilter=True)
                    .limit(top_k)
                    .to_list()
                )
            else:
                docs = (
                    document_collection.search(query, query_type="fts")
                    .limit(top_k)
                    .to_list()
                )
        except (ValueError, FileNotFoundError):
            docs = []
        return [
            Document(
                id_=doc["id"],
                text=doc["text"] if doc["text"] else "<empty>",
                metadata=json.loads(doc["attributes"]),
            )
            for doc in docs
        ]

    def get(self, ids: Union[List[str], str]) -> List[Document]:
        """Get document by id"""
        if not isinstance(ids, list):
            ids = [ids]

        id_filter = ", ".join([f"'{_id}'" for _id in ids])
        try:
            document_collection = self.db_connection.open_table(self.collection_name)
            query_filter = f"id in ({id_filter})"
            docs = (
                document_collection.search()
                .where(query_filter)
                .limit(MAX_DOCS_TO_GET)
                .to_list()
            )
        except (ValueError, FileNotFoundError):
            docs = []
        return [
            Document(
                id_=doc["id"],
                text=doc["text"] if doc["text"] else "<empty>",
                metadata=json.loads(doc["attributes"]),
            )
            for doc in docs
        ]

    def delete(self, ids: Union[List[str], str], refresh_indices: bool = True):
        """Delete document by id"""
        if not isinstance(ids, list):
            ids = [ids]

        document_collection = self.db_connection.open_table(self.collection_name)
        id_filter = ", ".join([f"'{_id}'" for _id in ids])
        query_filter = f"id in ({id_filter})"
        document_collection.delete(query_filter)

        if refresh_indices:
            document_collection.create_fts_index(
                "text",
                tokenizer_name="en_stem",
                replace=True,
            )

    def drop(self):
        """Drop the document store"""
        self.db_connection.drop_table(self.collection_name)

    def count(self) -> int:
        raise NotImplementedError

    def get_all(self) -> List[Document]:
        raise NotImplementedError

    def __persist_flow__(self):
        return {
            "db_uri": self.db_uri,
            "collection_name": self.collection_name,
        }
@@ -1,6 +1,7 @@
 from .base import BaseVectorStore
 from .chroma import ChromaVectorStore
 from .in_memory import InMemoryVectorStore
+from .lancedb import LanceDBVectorStore
 from .simple_file import SimpleFileVectorStore

 __all__ = [
@@ -8,4 +9,5 @@ __all__ = [
     "ChromaVectorStore",
     "InMemoryVectorStore",
     "SimpleFileVectorStore",
+    "LanceDBVectorStore",
 ]
@@ -3,10 +3,10 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import Any, Optional

-from llama_index.schema import NodeRelationship, RelatedNodeInfo
-from llama_index.vector_stores.types import BasePydanticVectorStore
-from llama_index.vector_stores.types import VectorStore as LIVectorStore
-from llama_index.vector_stores.types import VectorStoreQuery
+from llama_index.core.schema import NodeRelationship, RelatedNodeInfo
+from llama_index.core.vector_stores.types import BasePydanticVectorStore
+from llama_index.core.vector_stores.types import VectorStore as LIVectorStore
+from llama_index.core.vector_stores.types import VectorStoreQuery

 from kotaemon.base import DocumentWithEmbedding

@@ -2,8 +2,8 @@
 from typing import Any, Optional, Type

 import fsspec
-from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
-from llama_index.vector_stores.simple import SimpleVectorStoreData
+from llama_index.core.vector_stores import SimpleVectorStore as LISimpleVectorStore
+from llama_index.core.vector_stores.simple import SimpleVectorStoreData

 from .base import LlamaIndexVectorStore

libs/kotaemon/kotaemon/storages/vectorstores/lancedb.py (new file, 87 lines)
@@ -0,0 +1,87 @@
from typing import Any, List, Type, cast

from llama_index.core.vector_stores.types import MetadataFilters
from llama_index.vector_stores.lancedb import LanceDBVectorStore as LILanceDBVectorStore
from llama_index.vector_stores.lancedb import base as base_lancedb

from .base import LlamaIndexVectorStore

# custom monkey patch for LanceDB
original_to_lance_filter = base_lancedb._to_lance_filter


def custom_to_lance_filter(
    standard_filters: MetadataFilters, metadata_keys: list
) -> Any:
    for filter in standard_filters.filters:
        if isinstance(filter.value, list):
            # quote string values if the filter value is a list of strings
            if filter.value and isinstance(filter.value[0], str):
                filter.value = [f"'{v}'" for v in filter.value]

    return original_to_lance_filter(standard_filters, metadata_keys)


# skip table existence check
LILanceDBVectorStore._table_exists = lambda _: False
base_lancedb._to_lance_filter = custom_to_lance_filter


class LanceDBVectorStore(LlamaIndexVectorStore):
    _li_class: Type[LILanceDBVectorStore] = LILanceDBVectorStore

    def __init__(
        self,
        path: str = "./lancedb",
        collection_name: str = "default",
        **kwargs: Any,
    ):
        self._path = path
        self._collection_name = collection_name

        try:
            import lancedb
        except ImportError:
            raise ImportError(
                "Please install lancedb: 'pip install lancedb tantivy'"
            )

        db_connection = lancedb.connect(path)  # type: ignore
        try:
            table = db_connection.open_table(collection_name)
        except FileNotFoundError:
            table = None

        self._kwargs = kwargs

        # pass through for nice IDE support
        super().__init__(
            uri=path,
            table_name=collection_name,
            table=table,
            **kwargs,
        )
        self._client = cast(LILanceDBVectorStore, self._client)
        self._client._metadata_keys = ["file_id"]

    def delete(self, ids: List[str], **kwargs):
        """Delete vector embeddings from vector stores

        Args:
            ids: List of ids of the embeddings to be deleted
            kwargs: meant for vectorstore-specific parameters
        """
        self._client.delete_nodes(ids)

    def drop(self):
        """Delete entire collection from vector stores"""
        self._client.client.drop_table(self._collection_name)

    def count(self) -> int:
        raise NotImplementedError

    def __persist_flow__(self):
        return {
            "path": self._path,
            "collection_name": self._collection_name,
        }
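Why the filter patch matters: LanceDB filters are SQL-like strings, so list values of string type must be quoted before they land in an IN (...) clause — which is exactly what custom_to_lance_filter does (illustrative values):

file_ids = ["a1", "b2"]
quoted = [f"'{v}'" for v in file_ids]
print(f"file_id IN ({', '.join(quoted)})")  # file_id IN ('a1', 'b2')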
@@ -3,8 +3,8 @@ from pathlib import Path
 from typing import Any, Optional, Type

 import fsspec
-from llama_index.vector_stores import SimpleVectorStore as LISimpleVectorStore
-from llama_index.vector_stores.simple import SimpleVectorStoreData
+from llama_index.core.vector_stores import SimpleVectorStore as LISimpleVectorStore
+from llama_index.core.vector_stores.simple import SimpleVectorStoreData

 from kotaemon.base import DocumentWithEmbedding

@@ -26,9 +26,11 @@ dependencies = [
     "langchain-openai>=0.1.4,<0.2.0",
     "openai>=1.23.6,<2",
     "theflow>=0.8.6,<0.9.0",
-    "llama-index==0.9.48",
+    "llama-index>=0.10.40,<0.11.0",
+    "llama-index-vector-stores-chroma>=0.1.9",
+    "llama-index-vector-stores-lancedb",
     "llama-hub>=0.0.79,<0.1.0",
-    "gradio>=4.26.0,<5",
+    "gradio>=4.31.0,<4.40",
     "openpyxl>=3.1.2,<3.2",
     "cookiecutter>=2.6.0,<2.7",
     "click>=8.1.7,<9",
@@ -36,13 +38,9 @@ dependencies = [
     "trogon>=0.5.0,<0.6",
     "tenacity>=8.2.3,<8.3",
     "python-dotenv>=1.0.1,<1.1",
     "chromadb>=0.4.21,<0.5",
     "unstructured==0.13.4",
     "pypdf>=4.2.0,<4.3",
     "PyMuPDF>=1.23",
     "html2text==2024.2.26",
     "fastembed==0.2.6",
     "llama-cpp-python>=0.2.72,<0.3",
     "azure-ai-documentintelligence",
     "cohere>=5.3.2,<5.4",
 ]
 readme = "README.md"
@@ -63,11 +61,12 @@ adv = [
     "duckduckgo-search>=6.1.0,<6.2",
     "googlesearch-python>=1.2.4,<1.3",
     "python-docx>=1.1.0,<1.2",
     "unstructured[pdf]==0.13.4",
     "sentence_transformers==2.7.0",
     "elasticsearch>=8.13.0,<8.14",
     "pdfservices-sdk @ git+https://github.com/niallcm/pdfservices-python-sdk.git@bump-and-unfreeze-requirements",
     "beautifulsoup4>=4.12.3,<4.13",
     "plotly",
     "tabulate",
     "fast_langdetect",
     "azure-ai-documentintelligence",
 ]
 dev = [
     "ipython",
@@ -2,7 +2,7 @@ from pathlib import Path
 from unittest.mock import patch

 from langchain.schema import Document as LangchainDocument
-from llama_index.node_parser import SimpleNodeParser
+from llama_index.core.node_parser import SimpleNodeParser

 from kotaemon.base import Document
 from kotaemon.loaders import (

@@ -1,4 +1,4 @@
-from llama_index.schema import NodeRelationship
+from llama_index.core.schema import NodeRelationship

 from kotaemon.base import Document
 from kotaemon.indices.splitters import TokenSplitter