Migrate the MVP into kotaemon (#108)
- Migrate the MVP into kotaemon.
- Preliminarily include the pipeline within the chatbot interface.
- Organize the MVP as an application.

Todo:
- Add an info panel to view the planning of agents -> Fix streaming agents' output.

Resolve: #60
Resolve: #61
Resolve: #62
parent 230328c62f
commit 5a9d6f75be
@@ -45,16 +45,12 @@ repos:
       - id: prettier
         types_or: [markdown, yaml]
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v1.5.1"
+    rev: "v1.7.1"
     hooks:
       - id: mypy
-        additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
-        args:
-          [
-            "--check-untyped-defs",
-            "--ignore-missing-imports",
-            "--new-type-inference",
-          ]
+        additional_dependencies:
+          [types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
+        args: ["--check-untyped-defs", "--ignore-missing-imports"]
         exclude: "^templates/"
   - repo: https://github.com/codespell-project/codespell
     rev: v2.2.4
@@ -269,4 +269,5 @@ class RewooAgent(BaseAgent):
             total_tokens=total_token,
             total_cost=total_cost,
-            citation=citation,
+            metadata={"citation": citation},
         )
@@ -41,7 +41,7 @@ class BaseTool(BaseComponent):
         args_schema = self.args_schema
         if isinstance(tool_input, str):
             if args_schema is not None:
-                key_ = next(iter(args_schema.__fields__.keys()))
+                key_ = next(iter(args_schema.model_fields.keys()))
                 args_schema.validate({key_: tool_input})
             return tool_input
         else:
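Note: the `__fields__` -> `model_fields` switch above tracks pydantic v2's rename of the field registry. A minimal sketch of reading a schema's first field name under either pydantic version (the `ExampleArgs` model is hypothetical, for illustration only):

    from pydantic import BaseModel

    class ExampleArgs(BaseModel):  # hypothetical schema, not part of this diff
        query: str

    # pydantic v2 exposes fields via `model_fields`; v1 used `__fields__`
    fields = getattr(ExampleArgs, "model_fields", None) or ExampleArgs.__fields__
    first_key = next(iter(fields.keys()))  # -> "query"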
@@ -121,9 +121,11 @@ class BaseTool(BaseComponent):


 class ComponentTool(BaseTool):
-    """
-    A Tool based on another pipeline / BaseComponent to be used
-    as its main entry point
+    """Wrapper around other BaseComponent to use it as a tool
+
+    Args:
+        component: BaseComponent-based component to wrap
+        postprocessor: Optional postprocessor for the component output
     """

     component: BaseComponent
@@ -1,13 +1,11 @@
-from typing import AnyStr, Optional, Type, Union
+from typing import AnyStr, Optional, Type

 from pydantic import BaseModel, Field

-from kotaemon.llms import LLM, AzureChatOpenAI, ChatLLM
+from kotaemon.llms import BaseLLM

 from .base import BaseTool, ToolException

-BaseLLM = Union[ChatLLM, LLM]
-

 class LLMArgs(BaseModel):
     query: str = Field(..., description="a search question or prompt")
@@ -21,7 +19,7 @@ class LLMTool(BaseTool):
         "are confident in solving the problem "
         "yourself. Input can be any instruction."
     )
-    llm: BaseLLM = AzureChatOpenAI.withx()
+    llm: BaseLLM
     args_schema: Optional[Type[BaseModel]] = LLMArgs

     def _run_tool(self, query: AnyStr) -> str:
@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from typing import Iterator

 from theflow import Function, Node, Param, lazy
@@ -32,7 +33,9 @@ class BaseComponent(Function):
         return self.__call__(self.inflow.flow())

     @abstractmethod
-    def run(self, *args, **kwargs) -> Document | list[Document] | None:
+    def run(
+        self, *args, **kwargs
+    ) -> Document | list[Document] | Iterator[Document] | None:
         """Run the component."""
         ...
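Note: widening the `run` return type to include `Iterator[Document]` is what lets components stream partial output (the streaming-agents item in the commit message). A sketch of a conforming component, assuming only the `BaseComponent` and `Document` interfaces shown in this diff (`EchoStream` is hypothetical):

    from typing import Iterator

    from kotaemon.base import BaseComponent, Document

    class EchoStream(BaseComponent):  # hypothetical, for illustration only
        def run(self, text: str) -> Iterator[Document]:
            # yield one Document per word so a UI can render incrementally
            for word in text.split():
                yield Document(content=word)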
@@ -23,11 +23,13 @@ class Document(BaseDocument):
     store the raw content of the document. If specified, the class will use
     `content` to initialize the base llama_index class.

-    Args:
-        content: the raw content of the document.
+    Attributes:
+        content: raw content of the document, can be anything
+        source: id of the source of the Document. Optional.
     """

     content: Any
+    source: Optional[str] = None

     def __init__(self, content: Optional[Any] = None, *args, **kwargs):
         if content is None:
@@ -121,9 +121,12 @@ class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import OpenAIEmbeddings
+        except ImportError:
+            from langchain.embeddings import OpenAIEmbeddings

-        return langchain.emebddings.OpenAIEmbeddings
+        return OpenAIEmbeddings


 class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
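Note: the same try/except import fallback recurs in every Langchain wrapper touched by this commit: prefer `langchain_community` (where integrations live after the langchain 0.1 package split) and fall back to the legacy `langchain` namespace. A simplified sketch of how a wrapper consumes the resolved class, modeled on the `LCChatMixin.__init__` shown later in this diff (`LCWrapperSketch` is illustrative, not part of the codebase):

    class LCWrapperSketch:  # simplified, for illustration only
        def __init__(self, **params):
            self._lc_class = self._get_lc_class()  # resolved lazily, import-safe
            self._obj = self._lc_class(**params)   # the wrapped Langchain object

        def _get_lc_class(self):
            try:  # langchain >= 0.1 ships integrations in langchain-community
                from langchain_community.embeddings import OpenAIEmbeddings
            except ImportError:  # older installs keep them in langchain proper
                from langchain.embeddings import OpenAIEmbeddings
            return OpenAIEmbeddings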
@@ -148,9 +151,12 @@ class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import AzureOpenAIEmbeddings
+        except ImportError:
+            from langchain.embeddings import AzureOpenAIEmbeddings

-        return langchain.embeddings.AzureOpenAIEmbeddings
+        return AzureOpenAIEmbeddings


 class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
@@ -173,9 +179,12 @@ class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import CohereEmbeddings
+        except ImportError:
+            from langchain.embeddings import CohereEmbeddings

-        return langchain.embeddings.CohereEmbeddings
+        return CohereEmbeddings


 class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
@@ -192,6 +201,9 @@ class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings):
         )

     def _get_lc_class(self):
-        import langchain.embeddings
+        try:
+            from langchain_community.embeddings import HuggingFaceBgeEmbeddings
+        except ImportError:
+            from langchain.embeddings import HuggingFaceBgeEmbeddings

-        return langchain.embeddings.HuggingFaceBgeEmbeddings
+        return HuggingFaceBgeEmbeddings
@@ -11,6 +11,7 @@ from kotaemon.loaders import (
     MathpixPDFReader,
     OCRReader,
+    PandasExcelReader,
     UnstructuredReader,
 )
@@ -19,8 +20,16 @@ class DocumentIngestor(BaseComponent):

     Document types:
         - pdf
-        - xlsx
-        - docx
+        - xlsx, xls
+        - docx, doc
+
+    Args:
+        pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr"
+            - normal: parse pdf text
+            - mathpix: parse pdf text using mathpix
+            - ocr: parse pdf image using flax
+        doc_parsers: list of document parsers to parse the document
+        text_splitter: splitter to split the document into text nodes
     """

     pdf_mode: str = "normal"  # "normal", "mathpix", "ocr"
@@ -34,6 +43,9 @@ class DocumentIngestor(BaseComponent):
         """Get appropriate readers for the input files based on file extension"""
         file_extractor: dict[str, AutoReader | BaseReader] = {
+            ".xlsx": PandasExcelReader(),
             ".docx": UnstructuredReader(),
+            ".xls": UnstructuredReader(),
+            ".doc": UnstructuredReader(),
         }

         if self.pdf_mode == "normal":
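Note: reader selection is a plain suffix lookup in this mapping, so the new `.xlsx`/`.xls`/`.doc` entries are all it takes to cover the legacy Office formats. A usage sketch under that assumption (hypothetical file name; `load_data` call mirrors the reader interface shown elsewhere in this diff):

    from pathlib import Path

    from kotaemon.loaders import PandasExcelReader, UnstructuredReader

    # suffix-based dispatch, mirroring the mapping in the hunk above
    file_extractor = {
        ".xlsx": PandasExcelReader(),
        ".docx": UnstructuredReader(),
        ".xls": UnstructuredReader(),
        ".doc": UnstructuredReader(),
    }

    file_path = Path("report.xls")  # hypothetical input file
    reader = file_extractor[file_path.suffix.lower()]
    docs = reader.load_data(file_path)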
@@ -64,11 +64,7 @@ class CitationPipeline(BaseComponent):

     llm: BaseLLM

-    def run(
-        self,
-        context: str,
-        question: str,
-    ) -> QuestionAnswer:
+    def run(self, context: str, question: str):
         schema = QuestionAnswer.schema()
         function = {
             "name": schema["title"],
@@ -1,6 +1,6 @@
 import os

-from kotaemon.base import BaseComponent, Document, RetrievedDocument
+from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
 from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate

 from .citation import CitationPipeline
@@ -21,6 +21,9 @@ class CitationQAPipeline(BaseComponent):
         temperature=0,
         request_timeout=60,
     )
+    citation_pipeline: CitationPipeline = Node(
+        default_callback=lambda self: CitationPipeline(llm=self.llm)
+    )

     def _format_doc_text(self, text: str) -> str:
         """Format the text of each document"""
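Note: declaring the citation pipeline as a `Node` with a `default_callback` (instead of constructing it inside `run`, as the next hunk removes) lets theflow register it as a subcomponent and build it lazily from the parent's `llm`. The declaration pattern in isolation, assuming only the imports shown in this diff (`QASketch` is a hypothetical class):

    from kotaemon.base import BaseComponent, Node
    from kotaemon.llms import BaseLLM

    from .citation import CitationPipeline

    class QASketch(BaseComponent):  # hypothetical, for illustration only
        llm: BaseLLM
        # the callback receives the parent instance, so the subcomponent
        # is created on first use and shares the parent's llm
        citation_pipeline: CitationPipeline = Node(
            default_callback=lambda self: CitationPipeline(llm=self.llm)
        )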
@@ -52,9 +55,7 @@ class CitationQAPipeline(BaseComponent):
         self.log_progress(".prompt", prompt=prompt)
         answer_text = self.llm(prompt).text
         if use_citation:
-            # run citation pipeline
-            citation_pipeline = CitationPipeline(llm=self.llm)
-            citation = citation_pipeline(context=context, question=question)
+            citation = self.citation_pipeline(context=context, question=question)
         else:
             citation = None
@@ -23,17 +23,18 @@ class CohereReranking(BaseReranking):
             )

         cohere_client = cohere.Client(self.cohere_api_key)
+        compressed_docs: list[Document] = []

-        # output documents
-        compressed_docs = []
-        if len(documents) > 0:  # to avoid empty api call
-            _docs = [d.content for d in documents]
-            results = cohere_client.rerank(
-                model=self.model_name, query=query, documents=_docs, top_n=self.top_k
-            )
-            for r in results:
-                doc = documents[r.index]
-                doc.metadata["relevance_score"] = r.relevance_score
-                compressed_docs.append(doc)
+        if not documents:  # to avoid empty api call
+            return compressed_docs
+
+        _docs = [d.content for d in documents]
+        results = cohere_client.rerank(
+            model=self.model_name, query=query, documents=_docs, top_n=self.top_k
+        )
+        for r in results:
+            doc = documents[r.index]
+            doc.metadata["relevance_score"] = r.relevance_score
+            compressed_docs.append(doc)

         return compressed_docs
@@ -29,8 +29,19 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):


 class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter):
-    def __init__(self, window_size: int = 3, **params):
-        super().__init__(window_size=window_size, **params)
+    def __init__(
+        self,
+        window_size: int = 3,
+        window_metadata_key: str = "window",
+        original_text_metadata_key: str = "original_text",
+        **params,
+    ):
+        super().__init__(
+            window_size=window_size,
+            window_metadata_key=window_metadata_key,
+            original_text_metadata_key=original_text_metadata_key,
+            **params,
+        )

     def _get_li_class(self):
         from llama_index.node_parser import SentenceWindowNodeParser
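Note: the new keyword arguments mirror the options of LlamaIndex's `SentenceWindowNodeParser`, which backs this splitter. A usage sketch (import path for `SentenceWindowSplitter` assumed to match `TokenSplitter`'s, as used in the new test file below):

    from kotaemon.base import Document
    from kotaemon.indices.splitters import SentenceWindowSplitter  # path assumed

    doc = Document(content="First sentence. Second sentence. Third sentence.")
    splitter = SentenceWindowSplitter(window_size=3)  # sentences of context per side
    nodes = splitter([doc])  # each node keeps its surrounding window in metadata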
@@ -62,7 +62,7 @@ class VectorIndexing(BaseIndexing):
         embeddings = self.embedding(input_)
         self.vector_store.add(
             embeddings=embeddings,
-            ids=[t.id_ for t in input_],
+            ids=[t.doc_id for t in input_],
         )
         if self.doc_store:
             self.doc_store.add(input_)
@@ -99,7 +99,7 @@ class VectorRetrieval(BaseRetrieval):
         )

         emb: list[float] = self.embedding(text)[0].embedding
-        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
+        _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs)
         docs = self.doc_store.get(ids)
         result = [
             RetrievedDocument(**doc.to_dict(), score=score)
@@ -15,15 +15,23 @@ class LCChatMixin:
                 "Please return the relevant Langchain class in in _get_lc_class"
             )

-    def __init__(self, **params):
+    def __init__(self, stream: bool = False, **params):
         self._lc_class = self._get_lc_class()
         self._obj = self._lc_class(**params)
         self._kwargs: dict = params
+        self._stream = stream

         super().__init__()

     def run(
         self, messages: str | BaseMessage | list[BaseMessage], **kwargs
     ) -> LLMInterface:
+        if self._stream:
+            return self.stream(messages, **kwargs)  # type: ignore
+        return self.invoke(messages, **kwargs)
+
+    def invoke(
+        self, messages: str | BaseMessage | list[BaseMessage], **kwargs
+    ) -> LLMInterface:
         """Generate response from messages
@@ -68,6 +76,10 @@ class LCChatMixin:
             logits=[],
         )

+    def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs):
+        for response in self._obj.stream(input=messages, **kwargs):
+            yield LLMInterface(content=response.content)
+
     def to_langchain_format(self):
         return self._obj
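Note: with `stream=True`, `run` hands back the generator from `stream` instead of a single `LLMInterface`, so callers iterate over partial chunks. A usage sketch (credentials and deployment parameters elided; the call syntax assumes the component `__call__` dispatches to `run`, as elsewhere in kotaemon):

    from kotaemon.llms import AzureChatOpenAI

    llm = AzureChatOpenAI(stream=True)  # plus the usual Azure OpenAI parameters

    for chunk in llm("Summarize this document"):
        # each chunk is an LLMInterface carrying one streamed segment
        print(chunk.content, end="", flush=True)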
@@ -150,6 +162,9 @@ class AzureChatOpenAI(LCChatMixin, ChatLLM):
         )

     def _get_lc_class(self):
-        import langchain.chat_models
+        try:
+            from langchain_community.chat_models import AzureChatOpenAI
+        except ImportError:
+            from langchain.chat_models import AzureChatOpenAI

-        return langchain.chat_models.AzureChatOpenAI
+        return AzureChatOpenAI
@@ -186,6 +186,9 @@ class AzureOpenAI(LCCompletionMixin, LLM):
         )

     def _get_lc_class(self):
-        import langchain.llms as langchain_llms
+        try:
+            from langchain_community.llms import AzureOpenAI
+        except ImportError:
+            from langchain.llms import AzureOpenAI

-        return langchain_llms.AzureOpenAI
+        return AzureOpenAI
@@ -26,11 +26,7 @@ class OCRReader(BaseReader):
         self.ocr_endpoint = endpoint
         self.use_ocr = use_ocr

-    def load_data(
-        self,
-        file_path: Path,
-        **kwargs,
-    ) -> List[Document]:
+    def load_data(self, file_path: Path, **kwargs) -> List[Document]:
         """Load data using OCR reader

         Args:
@@ -41,23 +37,24 @@ class OCRReader(BaseReader):
         Returns:
             List[Document]: list of documents extracted from the PDF file
         """
-        # create input params for the requests
-        content = open(file_path, "rb")
-        files = {"input": content}
-        data = {"job_id": uuid4(), "table_only": not self.use_ocr}
+        file_path = Path(file_path).resolve()
+
+        with file_path.open("rb") as content:
+            files = {"input": content}
+            data = {"job_id": uuid4(), "table_only": not self.use_ocr}
+
+            # call the API from FullOCR endpoint
+            if "response_content" in kwargs:
+                # overriding response content if specified
+                ocr_results = kwargs["response_content"]
+            else:
+                # call original API
+                resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
+                ocr_results = resp.json()["result"]
+
+        debug_path = kwargs.pop("debug_path", None)
+        artifact_path = kwargs.pop("artifact_path", None)

-        # call the API from FullOCR endpoint
-        if "response_content" in kwargs:
-            # overriding response content if specified
-            ocr_results = kwargs["response_content"]
-        else:
-            # call original API
-            resp = requests.post(url=self.ocr_endpoint, files=files, data=data)
-            ocr_results = resp.json()["result"]

         # read PDF through normal reader (unstructured)
         pdf_page_items = read_pdf_unstructured(file_path)
         # merge PDF text output with OCR output
@@ -77,6 +74,9 @@ class OCRReader(BaseReader):
                     "type": "table",
                     "page_label": page_id + 1,
                     "source": file_path.name,
+                    "file_path": str(file_path),
+                    "file_name": file_path.name,
+                    "filename": str(file_path),
                 },
                 metadata_template="",
                 metadata_seperator="",
@@ -91,6 +91,9 @@ class OCRReader(BaseReader):
                 metadata={
                     "page_label": page_id + 1,
                     "source": file_path.name,
+                    "file_path": str(file_path),
+                    "file_name": file_path.name,
+                    "filename": str(file_path),
                 },
             )
             for page_id, non_table_text in texts
@@ -74,9 +74,10 @@ class UnstructuredReader(BaseReader):
         """ Process elements """
         docs = []
         file_name = Path(file).name
+        file_path = str(Path(file).resolve())
         if split_documents:
             for node in elements:
-                metadata = {"file_name": file_name}
+                metadata = {"file_name": file_name, "file_path": file_path}
                 if hasattr(node, "metadata"):
                     """Load metadata fields"""
                     for field, val in vars(node.metadata).items():
@@ -99,7 +100,7 @@ class UnstructuredReader(BaseReader):

         else:
             text_chunks = [" ".join(str(el).split()) for el in elements]
-            metadata = {"file_name": file_name}
+            metadata = {"file_name": file_name, "file_path": file_path}

             if additional_metadata is not None:
                 metadata.update(additional_metadata)
@@ -16,6 +16,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         elasticsearch_url: str = "http://localhost:9200",
         k1: float = 2.0,
         b: float = 0.75,
+        **kwargs,
     ):
         try:
             from elasticsearch import Elasticsearch
@@ -31,7 +32,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self.b = b

         # Create an Elasticsearch client instance
-        self.client = Elasticsearch(elasticsearch_url)
+        self.client = Elasticsearch(elasticsearch_url, **kwargs)
         self.es_bulk = bulk
         # Define the index settings and mappings
         settings = {
@@ -63,19 +64,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         self,
         docs: Union[Document, List[Document]],
         ids: Optional[Union[List[str], str]] = None,
-        **kwargs
+        refresh_indices: bool = True,
+        **kwargs,
     ):
         """Add document into document store

         Args:
             docs: list of documents to add
-            ids: specify the ids of documents to add or
-                use existing doc.doc_id
-            refresh_indices: request Elasticsearch to update
-                its index (default to True)
+            ids: specify the ids of documents to add or use existing doc.doc_id
+            refresh_indices: request Elasticsearch to update its index (default to True)
         """
-        refresh_indices = kwargs.pop("refresh_indices", True)
-
         if ids and not isinstance(ids, list):
             ids = [ids]
         if not isinstance(docs, list):
@@ -120,7 +118,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore):
         )
         return docs

-    def query(self, query: str, top_k: int = 10) -> List[Document]:
+    def query(
+        self, query: str, top_k: int = 10, doc_ids: Optional[list] = None
+    ) -> List[Document]:
         """Search Elasticsearch docstore using search query (BM25)

         Args:
|
|||
Returns:
|
||||
List[Document]: List of result documents
|
||||
"""
|
||||
query_dict = {"query": {"match": {"content": query}}, "size": top_k}
|
||||
query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k}
|
||||
if doc_ids:
|
||||
query_dict["query"]["match"]["_id"] = {"values": doc_ids}
|
||||
return self.query_raw(query_dict)
|
||||
|
||||
def get(self, ids: Union[List[str], str]) -> List[Document]:
|
||||
|
|
|
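Note: the `doc_ids` filter only mutates the raw BM25 query body before it is sent. As a worked example, for `query="cat"`, `top_k=10`, `doc_ids=["a", "b"]`, the dict built above becomes:

    query_dict = {
        "query": {
            "match": {
                "content": "cat",
                "_id": {"values": ["a", "b"]},  # present only when doc_ids is given
            }
        },
        "size": 10,
    }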
@@ -74,6 +74,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
         """Load document store from path"""
         with open(path) as f:
             store = json.load(f)
+        # TODO: save and load aren't lossless. A Document-subclass will lose
+        # information. Need to edit the `to_dict` and `from_dict` methods in
+        # the Document class.
+        # For better query support, utilize SQLite as the default document store.
+        # Also, for portability, use SQLAlchemy for document store.
         self._store = {key: Document.from_dict(value) for key, value in store.items()}

     def __persist_flow__(self):
@@ -15,6 +15,18 @@ class SimpleFileDocumentStore(InMemoryDocumentStore):
         if path is not None and Path(path).is_file():
             self.load(path)

+    def get(self, ids: Union[List[str], str]) -> List[Document]:
+        """Get document by id"""
+        if not isinstance(ids, list):
+            ids = [ids]
+
+        for doc_id in ids:
+            if doc_id not in self._store:
+                self.load(self._path)
+                break
+
+        return [self._store[doc_id] for doc_id in ids]
+
     def add(
         self,
         docs: Union[Document, List[Document]],
@@ -76,8 +76,15 @@ class LlamaIndexVectorStore(BaseVectorStore):
                 "Require `_li_class` to set a VectorStore class from LlamarIndex"
             )

+        from dataclasses import fields
+
         self._client = self._li_class(*args, **kwargs)

+        self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)}
+        for key in ["query_embedding", "similarity_top_k", "node_ids"]:
+            if key in self._vsq_kwargs:
+                self._vsq_kwargs.remove(key)
+
     def __setattr__(self, name: str, value: Any) -> None:
         if name.startswith("_"):
             return super().__setattr__(name, value)
@@ -122,13 +129,35 @@ class LlamaIndexVectorStore(BaseVectorStore):
         ids: Optional[list[str]] = None,
         **kwargs,
     ) -> tuple[list[list[float]], list[float], list[str]]:
         """Return the top k most similar vector embeddings

         Args:
             embedding: List of embeddings
             top_k: Number of most similar embeddings to return
             ids: List of ids of the embeddings to be queried
+            kwargs: extra query parameters. Depending on the name, these parameters
+                will be used when constructing the VectorStoreQuery object or when
+                performing querying of the underlying vector store.

         Returns:
             the matched embeddings, the similarity scores, and the ids
         """
+        vsq_kwargs = {}
+        vs_kwargs = {}
+        for kwkey, kwvalue in kwargs.items():
+            if kwkey in self._vsq_kwargs:
+                vsq_kwargs[kwkey] = kwvalue
+            else:
+                vs_kwargs[kwkey] = kwvalue
+
         output = self._client.query(
             query=VectorStoreQuery(
                 query_embedding=embedding,
                 similarity_top_k=top_k,
                 node_ids=ids,
-                **kwargs,
+                **vsq_kwargs,
             ),
+            **vs_kwargs,
         )

         embeddings = []
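Note: because `kwargs` is split against the `VectorStoreQuery` dataclass fields collected in `__init__`, a caller can mix query-object parameters and store-specific parameters in a single call. Combined with the `**kwargs` passthrough added to `VectorRetrieval` earlier in this diff, a retrieval might look like the following sketch (`vector_store` is a `LlamaIndexVectorStore` subclass instance and `emb` an embedding vector; `mode` is a `VectorStoreQuery` field, while `where` is a hypothetical store-specific parameter):

    matched, scores, ids = vector_store.query(
        embedding=emb,
        top_k=5,
        mode="default",              # routed into the VectorStoreQuery object
        where={"topic": "finance"},  # unknown to the dataclass -> store client
    )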
@@ -64,7 +64,7 @@ class ChromaVectorStore(LlamaIndexVectorStore):
             ids: List of ids of the embeddings to be deleted
             kwargs: meant for vectorstore-specific parameters
         """
-        self._client._collection.delete(ids=ids)
+        self._client.client.delete(ids=ids)

     def delete_collection(self, collection_name: Optional[str] = None):
         """Delete entire collection under specified name from vector stores
@@ -16,6 +16,7 @@ requires-python = ">= 3.10"
 description = "Kotaemon core library for AI development."
 dependencies = [
     "langchain",
+    "langchain-community",
     "theflow",
     "llama-index>=0.9.0",
     "llama-hub",
@@ -56,7 +57,6 @@ dev = [
     "python-dotenv",
     "pytest-mock",
     "unstructured[pdf]",
-    # "farm-haystack==1.22.1",
     "sentence_transformers",
     "cohere",
     "elasticsearch",
@@ -47,6 +47,7 @@ def generate_chat_completion_obj(text):
                 "function_call": None,
                 "tool_calls": None,
             },
+            "logprobs": None,
         }
     ],
     "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
@@ -30,6 +30,7 @@ _openai_chat_completion_response = [
                 },
                 "tool_calls": None,
             },
+            "logprobs": None,
         }
     ],
     "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
@@ -30,6 +30,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
                 "finish_reason": "length",
                 "logprobs": None,
             },
+            "logprobs": None,
         }
     ],
     "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
@@ -23,6 +23,7 @@ _openai_chat_completion_response = [
                 "function_call": None,
                 "tool_calls": None,
             },
+            "logprobs": None,
        }
     ],
     "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
tests/test_ingestor.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from pathlib import Path
+
+from kotaemon.indices.ingests import DocumentIngestor
+from kotaemon.indices.splitters import TokenSplitter
+
+
+def test_ingestor_include_src():
+    dirpath = Path(__file__).parent
+    ingestor = DocumentIngestor(
+        pdf_mode="normal",
+        text_splitter=TokenSplitter(chunk_size=50, chunk_overlap=10),
+    )
+    nodes = ingestor(dirpath / "resources" / "table.pdf")
+    assert type(nodes) is list
+    assert nodes[0].relationships
@@ -28,6 +28,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
                 "function_call": None,
                 "tool_calls": None,
             },
+            "logprobs": None,
         }
     ],
     "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
@@ -25,6 +25,7 @@ _openai_chat_completion_responses = [
                 "function_call": None,
                 "tool_calls": None,
             },
+            "logprobs": None,
         }
     ],
     "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
tests/test_splitter.py (new file, 55 lines)
@@ -0,0 +1,55 @@
+from llama_index.schema import NodeRelationship
+
+from kotaemon.base import Document
+from kotaemon.indices.splitters import TokenSplitter
+
+source1 = Document(
+    content="The City Hall and Raffles Place MRT stations are paired cross-platform "
+    "interchanges on the North–South line (NSL) and East–West line (EWL) of the "
+    "Singapore Mass Rapid Transit (MRT) system. Both are situated in the Downtown "
+    "Core district: City Hall station is near landmarks such as the former City Hall, "
+    "St Andrew's Cathedral and the Padang, while Raffles Place station serves Merlion "
+    "Park, The Fullerton Hotel and the Asian Civilisations Museum. The stations were "
+    "first announced in 1982. Constructing the tunnels between the City Hall and "
+    "Raffles Place stations required the draining of the Singapore River. The "
+    "stations opened on 12 December 1987 as part of the MRT extension to Outram Park "
+    "station. Cross-platform transfers between the NSL and EWL began on 28 October "
+    "1989, ahead of the split of the MRT network into two lines. Both stations are "
+    "designated Civil Defence shelters. City Hall station features a mural by Simon"
+    "Wong which depicts government buildings in the area, while two murals at Raffles "
+    "Place station by Lim Sew Yong and Thang Kiang How depict scenes of Singapore's "
+    "history"
+)
+
+source2 = Document(
+    content="The pink cockatoo (Cacatua leadbeateri) is a medium-sized cockatoo that "
+    "inhabits arid and semi-arid inland areas across Australia, with the exception of "
+    "the north east. The bird has a soft-textured white and salmon-pink plumage and "
+    "large, bright red and yellow crest. The sexes are quite similar, although males "
+    "are usually bigger while the female has a broader yellow stripe on the crest and "
+    "develops a red eye when mature. The pink cockatoo is usually found in pairs or "
+    "small groups, and feeds both on the ground and in trees. It is listed as an "
+    "endangered species by the Australian government. Formerly known as Major "
+    "Mitchell's cockatoo, after the explorer Thomas Mitchell, the species was "
+    "officially renamed the pink cockatoo in 2023 by BirdLife Australia in light of "
+    "Mitchell's involvement in the massacre of Aboriginal people at Mount Dispersion, "
+    "as well as a general trend to make Australian species names more culturally "
+    "inclusive. This pink cockatoo with a raised crest was photographed near Mount "
+    "Grenfell in New South Wales."
+)
+
+
+def test_split_token():
+    """Test that it can split tokens successfully"""
+    splitter = TokenSplitter(chunk_size=30, chunk_overlap=10)
+    chunks = splitter([source1, source2])
+
+    assert isinstance(chunks, list), "Chunks should be a list"
+    assert isinstance(chunks[0], Document), "Chunks should be a list of Documents"
+
+    assert chunks[0].relationships[NodeRelationship.SOURCE].node_id == source1.doc_id
+    assert (
+        chunks[1].relationships[NodeRelationship.PREVIOUS].node_id == chunks[0].doc_id
+    )
+    assert chunks[1].relationships[NodeRelationship.NEXT].node_id == chunks[2].doc_id
+    assert chunks[-1].relationships[NodeRelationship.SOURCE].node_id == source2.doc_id