From 5a9d6f75be5184e3a1a32d2e0f144df9bf6c40d4 Mon Sep 17 00:00:00 2001
From: "Duc Nguyen (john)"
Date: Wed, 10 Jan 2024 15:28:09 +0700
Subject: [PATCH] Migrate the MVP into kotaemon (#108)

- Migrate the MVP into kotaemon.
- Preliminarily include the pipeline within the chatbot interface.
- Organize the MVP as an application.

Todo:

- Add an info panel to view the planning of agents -> fix streaming of the
  agents' output.

Resolve: #60
Resolve: #61
Resolve: #62
---
 .pre-commit-config.yaml                      | 12 ++--
 knowledgehub/agents/rewoo/agent.py           |  1 +
 knowledgehub/agents/tools/base.py            | 10 ++--
 knowledgehub/agents/tools/llm.py             |  8 +--
 knowledgehub/base/component.py               |  5 +-
 knowledgehub/base/schema.py                  |  6 +-
 knowledgehub/embeddings/langchain_based.py   | 28 +++++++---
 knowledgehub/indices/ingests/files.py        | 16 +++++-
 knowledgehub/indices/qa/citation.py          |  6 +-
 knowledgehub/indices/qa/text_based.py        |  9 +--
 knowledgehub/indices/rankings/cohere.py      | 23 ++++----
 knowledgehub/indices/splitters/__init__.py   | 15 ++++-
 knowledgehub/indices/vectorindex.py          |  4 +-
 knowledgehub/llms/chats/langchain_based.py   | 21 ++++++-
 .../llms/completions/langchain_based.py      |  7 ++-
 knowledgehub/loaders/ocr_loader.py           | 39 +++++++------
 knowledgehub/loaders/unstructured_loader.py  |  5 +-
 .../storages/docstores/elasticsearch.py      | 22 ++++----
 knowledgehub/storages/docstores/in_memory.py |  5 ++
 .../storages/docstores/simple_file.py        | 12 ++++
 knowledgehub/storages/vectorstores/base.py   | 31 ++++++++++-
 knowledgehub/storages/vectorstores/chroma.py |  2 +-
 pyproject.toml                               |  2 +-
 tests/test_agent.py                          |  1 +
 tests/test_citation.py                       |  1 +
 tests/test_composite.py                      |  1 +
 tests/test_cot.py                            |  1 +
 tests/test_ingestor.py                       | 15 +++++
 tests/test_llms_chat_models.py               |  1 +
 tests/test_reranking.py                      |  1 +
 tests/test_splitter.py                       | 55 +++++++++++++++++++
 31 files changed, 273 insertions(+), 92 deletions(-)
 create mode 100644 tests/test_ingestor.py
 create mode 100644 tests/test_splitter.py

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 4a15646..0ed14b7 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -45,16 +45,12 @@ repos:
       - id: prettier
         types_or: [markdown, yaml]
   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: "v1.5.1"
+    rev: "v1.7.1"
     hooks:
       - id: mypy
-        additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
-        args:
-          [
-            "--check-untyped-defs",
-            "--ignore-missing-imports",
-            "--new-type-inference",
-          ]
+        additional_dependencies:
+          [types-PyYAML==6.0.12.11, "types-requests", "sqlmodel"]
+        args: ["--check-untyped-defs", "--ignore-missing-imports"]
         exclude: "^templates/"
   - repo: https://github.com/codespell-project/codespell
     rev: v2.2.4
diff --git a/knowledgehub/agents/rewoo/agent.py b/knowledgehub/agents/rewoo/agent.py
index 81f8b8a..4255c6a 100644
--- a/knowledgehub/agents/rewoo/agent.py
+++ b/knowledgehub/agents/rewoo/agent.py
@@ -269,4 +269,5 @@ class RewooAgent(BaseAgent):
             total_tokens=total_token,
             total_cost=total_cost,
             citation=citation,
+            metadata={"citation": citation},
         )
diff --git a/knowledgehub/agents/tools/base.py b/knowledgehub/agents/tools/base.py
index b06f2c4..1caf3d2 100644
--- a/knowledgehub/agents/tools/base.py
+++ b/knowledgehub/agents/tools/base.py
@@ -41,7 +41,7 @@ class BaseTool(BaseComponent):
         args_schema = self.args_schema
         if isinstance(tool_input, str):
             if args_schema is not None:
-                key_ = next(iter(args_schema.__fields__.keys()))
+                key_ = next(iter(args_schema.model_fields.keys()))
                 args_schema.validate({key_: tool_input})
             return tool_input
         else:
@@ -121,9 +121,11 @@ class 
BaseTool(BaseComponent): class ComponentTool(BaseTool): - """ - A Tool based on another pipeline / BaseComponent to be used - as its main entry point + """Wrapper around other BaseComponent to use it as a tool + + Args: + component: BaseComponent-based component to wrap + postprocessor: Optional postprocessor for the component output """ component: BaseComponent diff --git a/knowledgehub/agents/tools/llm.py b/knowledgehub/agents/tools/llm.py index 62c6fef..750462e 100644 --- a/knowledgehub/agents/tools/llm.py +++ b/knowledgehub/agents/tools/llm.py @@ -1,13 +1,11 @@ -from typing import AnyStr, Optional, Type, Union +from typing import AnyStr, Optional, Type from pydantic import BaseModel, Field -from kotaemon.llms import LLM, AzureChatOpenAI, ChatLLM +from kotaemon.llms import BaseLLM from .base import BaseTool, ToolException -BaseLLM = Union[ChatLLM, LLM] - class LLMArgs(BaseModel): query: str = Field(..., description="a search question or prompt") @@ -21,7 +19,7 @@ class LLMTool(BaseTool): "are confident in solving the problem " "yourself. Input can be any instruction." ) - llm: BaseLLM = AzureChatOpenAI.withx() + llm: BaseLLM args_schema: Optional[Type[BaseModel]] = LLMArgs def _run_tool(self, query: AnyStr) -> str: diff --git a/knowledgehub/base/component.py b/knowledgehub/base/component.py index b62658c..90823ae 100644 --- a/knowledgehub/base/component.py +++ b/knowledgehub/base/component.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from typing import Iterator from theflow import Function, Node, Param, lazy @@ -32,7 +33,9 @@ class BaseComponent(Function): return self.__call__(self.inflow.flow()) @abstractmethod - def run(self, *args, **kwargs) -> Document | list[Document] | None: + def run( + self, *args, **kwargs + ) -> Document | list[Document] | Iterator[Document] | None: """Run the component.""" ... diff --git a/knowledgehub/base/schema.py b/knowledgehub/base/schema.py index ee19d3a..1d0e622 100644 --- a/knowledgehub/base/schema.py +++ b/knowledgehub/base/schema.py @@ -23,11 +23,13 @@ class Document(BaseDocument): store the raw content of the document. If specified, the class will use `content` to initialize the base llama_index class. - Args: - content: the raw content of the document. + Attributes: + content: raw content of the document, can be anything + source: id of the source of the Document. Optional. 
""" content: Any + source: Optional[str] = None def __init__(self, content: Optional[Any] = None, *args, **kwargs): if content is None: diff --git a/knowledgehub/embeddings/langchain_based.py b/knowledgehub/embeddings/langchain_based.py index d24e4a2..0aef4f1 100644 --- a/knowledgehub/embeddings/langchain_based.py +++ b/knowledgehub/embeddings/langchain_based.py @@ -121,9 +121,12 @@ class OpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings): ) def _get_lc_class(self): - import langchain.embeddings + try: + from langchain_community.embeddings import OpenAIEmbeddings + except ImportError: + from langchain.embeddings import OpenAIEmbeddings - return langchain.emebddings.OpenAIEmbeddings + return OpenAIEmbeddings class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings): @@ -148,9 +151,12 @@ class AzureOpenAIEmbeddings(LCEmbeddingMixin, BaseEmbeddings): ) def _get_lc_class(self): - import langchain.embeddings + try: + from langchain_community.embeddings import AzureOpenAIEmbeddings + except ImportError: + from langchain.embeddings import AzureOpenAIEmbeddings - return langchain.embeddings.AzureOpenAIEmbeddings + return AzureOpenAIEmbeddings class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings): @@ -173,9 +179,12 @@ class CohereEmbdeddings(LCEmbeddingMixin, BaseEmbeddings): ) def _get_lc_class(self): - import langchain.embeddings + try: + from langchain_community.embeddings import CohereEmbeddings + except ImportError: + from langchain.embeddings import CohereEmbeddings - return langchain.embeddings.CohereEmbeddings + return CohereEmbeddings class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings): @@ -192,6 +201,9 @@ class HuggingFaceEmbeddings(LCEmbeddingMixin, BaseEmbeddings): ) def _get_lc_class(self): - import langchain.embeddings + try: + from langchain_community.embeddings import HuggingFaceBgeEmbeddings + except ImportError: + from langchain.embeddings import HuggingFaceBgeEmbeddings - return langchain.embeddings.HuggingFaceBgeEmbeddings + return HuggingFaceBgeEmbeddings diff --git a/knowledgehub/indices/ingests/files.py b/knowledgehub/indices/ingests/files.py index 1d7db35..22e7db9 100644 --- a/knowledgehub/indices/ingests/files.py +++ b/knowledgehub/indices/ingests/files.py @@ -11,6 +11,7 @@ from kotaemon.loaders import ( MathpixPDFReader, OCRReader, PandasExcelReader, + UnstructuredReader, ) @@ -19,8 +20,16 @@ class DocumentIngestor(BaseComponent): Document types: - pdf - - xlsx - - docx + - xlsx, xls + - docx, doc + + Args: + pdf_mode: mode for pdf extraction, one of "normal", "mathpix", "ocr" + - normal: parse pdf text + - mathpix: parse pdf text using mathpix + - ocr: parse pdf image using flax + doc_parsers: list of document parsers to parse the document + text_splitter: splitter to split the document into text nodes """ pdf_mode: str = "normal" # "normal", "mathpix", "ocr" @@ -34,6 +43,9 @@ class DocumentIngestor(BaseComponent): """Get appropriate readers for the input files based on file extension""" file_extractor: dict[str, AutoReader | BaseReader] = { ".xlsx": PandasExcelReader(), + ".docx": UnstructuredReader(), + ".xls": UnstructuredReader(), + ".doc": UnstructuredReader(), } if self.pdf_mode == "normal": diff --git a/knowledgehub/indices/qa/citation.py b/knowledgehub/indices/qa/citation.py index 374a5f3..4c1281a 100644 --- a/knowledgehub/indices/qa/citation.py +++ b/knowledgehub/indices/qa/citation.py @@ -64,11 +64,7 @@ class CitationPipeline(BaseComponent): llm: BaseLLM - def run( - self, - context: str, - question: str, - ) -> QuestionAnswer: + def run(self, 
context: str, question: str): schema = QuestionAnswer.schema() function = { "name": schema["title"], diff --git a/knowledgehub/indices/qa/text_based.py b/knowledgehub/indices/qa/text_based.py index 9c46fc7..5b1f6e3 100644 --- a/knowledgehub/indices/qa/text_based.py +++ b/knowledgehub/indices/qa/text_based.py @@ -1,6 +1,6 @@ import os -from kotaemon.base import BaseComponent, Document, RetrievedDocument +from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate from .citation import CitationPipeline @@ -21,6 +21,9 @@ class CitationQAPipeline(BaseComponent): temperature=0, request_timeout=60, ) + citation_pipeline: CitationPipeline = Node( + default_callback=lambda self: CitationPipeline(llm=self.llm) + ) def _format_doc_text(self, text: str) -> str: """Format the text of each document""" @@ -52,9 +55,7 @@ class CitationQAPipeline(BaseComponent): self.log_progress(".prompt", prompt=prompt) answer_text = self.llm(prompt).text if use_citation: - # run citation pipeline - citation_pipeline = CitationPipeline(llm=self.llm) - citation = citation_pipeline(context=context, question=question) + citation = self.citation_pipeline(context=context, question=question) else: citation = None diff --git a/knowledgehub/indices/rankings/cohere.py b/knowledgehub/indices/rankings/cohere.py index 1f9c32a..d102efd 100644 --- a/knowledgehub/indices/rankings/cohere.py +++ b/knowledgehub/indices/rankings/cohere.py @@ -23,17 +23,18 @@ class CohereReranking(BaseReranking): ) cohere_client = cohere.Client(self.cohere_api_key) + compressed_docs: list[Document] = [] - # output documents - compressed_docs = [] - if len(documents) > 0: # to avoid empty api call - _docs = [d.content for d in documents] - results = cohere_client.rerank( - model=self.model_name, query=query, documents=_docs, top_n=self.top_k - ) - for r in results: - doc = documents[r.index] - doc.metadata["relevance_score"] = r.relevance_score - compressed_docs.append(doc) + if not documents: # to avoid empty api call + return compressed_docs + + _docs = [d.content for d in documents] + results = cohere_client.rerank( + model=self.model_name, query=query, documents=_docs, top_n=self.top_k + ) + for r in results: + doc = documents[r.index] + doc.metadata["relevance_score"] = r.relevance_score + compressed_docs.append(doc) return compressed_docs diff --git a/knowledgehub/indices/splitters/__init__.py b/knowledgehub/indices/splitters/__init__.py index 0c71a41..16a31fe 100644 --- a/knowledgehub/indices/splitters/__init__.py +++ b/knowledgehub/indices/splitters/__init__.py @@ -29,8 +29,19 @@ class TokenSplitter(LlamaIndexDocTransformerMixin, BaseSplitter): class SentenceWindowSplitter(LlamaIndexDocTransformerMixin, BaseSplitter): - def __init__(self, window_size: int = 3, **params): - super().__init__(window_size=window_size, **params) + def __init__( + self, + window_size: int = 3, + window_metadata_key: str = "window", + original_text_metadata_key: str = "original_text", + **params, + ): + super().__init__( + window_size=window_size, + window_metadata_key=window_metadata_key, + original_text_metadata_key=original_text_metadata_key, + **params, + ) def _get_li_class(self): from llama_index.node_parser import SentenceWindowNodeParser diff --git a/knowledgehub/indices/vectorindex.py b/knowledgehub/indices/vectorindex.py index 962cc9f..db2c696 100644 --- a/knowledgehub/indices/vectorindex.py +++ b/knowledgehub/indices/vectorindex.py @@ -62,7 +62,7 @@ class 
VectorIndexing(BaseIndexing): embeddings = self.embedding(input_) self.vector_store.add( embeddings=embeddings, - ids=[t.id_ for t in input_], + ids=[t.doc_id for t in input_], ) if self.doc_store: self.doc_store.add(input_) @@ -99,7 +99,7 @@ class VectorRetrieval(BaseRetrieval): ) emb: list[float] = self.embedding(text)[0].embedding - _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k) + _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k, **kwargs) docs = self.doc_store.get(ids) result = [ RetrievedDocument(**doc.to_dict(), score=score) diff --git a/knowledgehub/llms/chats/langchain_based.py b/knowledgehub/llms/chats/langchain_based.py index 7d4eb76..9a815ad 100644 --- a/knowledgehub/llms/chats/langchain_based.py +++ b/knowledgehub/llms/chats/langchain_based.py @@ -15,15 +15,23 @@ class LCChatMixin: "Please return the relevant Langchain class in in _get_lc_class" ) - def __init__(self, **params): + def __init__(self, stream: bool = False, **params): self._lc_class = self._get_lc_class() self._obj = self._lc_class(**params) self._kwargs: dict = params + self._stream = stream super().__init__() def run( self, messages: str | BaseMessage | list[BaseMessage], **kwargs + ) -> LLMInterface: + if self._stream: + return self.stream(messages, **kwargs) # type: ignore + return self.invoke(messages, **kwargs) + + def invoke( + self, messages: str | BaseMessage | list[BaseMessage], **kwargs ) -> LLMInterface: """Generate response from messages @@ -68,6 +76,10 @@ class LCChatMixin: logits=[], ) + def stream(self, messages: str | BaseMessage | list[BaseMessage], **kwargs): + for response in self._obj.stream(input=messages, **kwargs): + yield LLMInterface(content=response.content) + def to_langchain_format(self): return self._obj @@ -150,6 +162,9 @@ class AzureChatOpenAI(LCChatMixin, ChatLLM): ) def _get_lc_class(self): - import langchain.chat_models + try: + from langchain_community.chat_models import AzureChatOpenAI + except ImportError: + from langchain.chat_models import AzureChatOpenAI - return langchain.chat_models.AzureChatOpenAI + return AzureChatOpenAI diff --git a/knowledgehub/llms/completions/langchain_based.py b/knowledgehub/llms/completions/langchain_based.py index 0048ef6..65122dd 100644 --- a/knowledgehub/llms/completions/langchain_based.py +++ b/knowledgehub/llms/completions/langchain_based.py @@ -186,6 +186,9 @@ class AzureOpenAI(LCCompletionMixin, LLM): ) def _get_lc_class(self): - import langchain.llms as langchain_llms + try: + from langchain_community.llms import AzureOpenAI + except ImportError: + from langchain.llms import AzureOpenAI - return langchain_llms.AzureOpenAI + return AzureOpenAI diff --git a/knowledgehub/loaders/ocr_loader.py b/knowledgehub/loaders/ocr_loader.py index e2c2bc7..f9e6fe9 100644 --- a/knowledgehub/loaders/ocr_loader.py +++ b/knowledgehub/loaders/ocr_loader.py @@ -26,11 +26,7 @@ class OCRReader(BaseReader): self.ocr_endpoint = endpoint self.use_ocr = use_ocr - def load_data( - self, - file_path: Path, - **kwargs, - ) -> List[Document]: + def load_data(self, file_path: Path, **kwargs) -> List[Document]: """Load data using OCR reader Args: @@ -41,23 +37,24 @@ class OCRReader(BaseReader): Returns: List[Document]: list of documents extracted from the PDF file """ - # create input params for the requests - content = open(file_path, "rb") - files = {"input": content} - data = {"job_id": uuid4(), "table_only": not self.use_ocr} + file_path = Path(file_path).resolve() + + with file_path.open("rb") as content: + files = {"input": 
content} + data = {"job_id": uuid4(), "table_only": not self.use_ocr} + + # call the API from FullOCR endpoint + if "response_content" in kwargs: + # overriding response content if specified + ocr_results = kwargs["response_content"] + else: + # call original API + resp = requests.post(url=self.ocr_endpoint, files=files, data=data) + ocr_results = resp.json()["result"] debug_path = kwargs.pop("debug_path", None) artifact_path = kwargs.pop("artifact_path", None) - # call the API from FullOCR endpoint - if "response_content" in kwargs: - # overriding response content if specified - ocr_results = kwargs["response_content"] - else: - # call original API - resp = requests.post(url=self.ocr_endpoint, files=files, data=data) - ocr_results = resp.json()["result"] - # read PDF through normal reader (unstructured) pdf_page_items = read_pdf_unstructured(file_path) # merge PDF text output with OCR output @@ -77,6 +74,9 @@ class OCRReader(BaseReader): "type": "table", "page_label": page_id + 1, "source": file_path.name, + "file_path": str(file_path), + "file_name": file_path.name, + "filename": str(file_path), }, metadata_template="", metadata_seperator="", @@ -91,6 +91,9 @@ class OCRReader(BaseReader): metadata={ "page_label": page_id + 1, "source": file_path.name, + "file_path": str(file_path), + "file_name": file_path.name, + "filename": str(file_path), }, ) for page_id, non_table_text in texts diff --git a/knowledgehub/loaders/unstructured_loader.py b/knowledgehub/loaders/unstructured_loader.py index 3972664..c386bc5 100644 --- a/knowledgehub/loaders/unstructured_loader.py +++ b/knowledgehub/loaders/unstructured_loader.py @@ -74,9 +74,10 @@ class UnstructuredReader(BaseReader): """ Process elements """ docs = [] file_name = Path(file).name + file_path = str(Path(file).resolve()) if split_documents: for node in elements: - metadata = {"file_name": file_name} + metadata = {"file_name": file_name, "file_path": file_path} if hasattr(node, "metadata"): """Load metadata fields""" for field, val in vars(node.metadata).items(): @@ -99,7 +100,7 @@ class UnstructuredReader(BaseReader): else: text_chunks = [" ".join(str(el).split()) for el in elements] - metadata = {"file_name": file_name} + metadata = {"file_name": file_name, "file_path": file_path} if additional_metadata is not None: metadata.update(additional_metadata) diff --git a/knowledgehub/storages/docstores/elasticsearch.py b/knowledgehub/storages/docstores/elasticsearch.py index 3d93c62..d581449 100644 --- a/knowledgehub/storages/docstores/elasticsearch.py +++ b/knowledgehub/storages/docstores/elasticsearch.py @@ -16,6 +16,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore): elasticsearch_url: str = "http://localhost:9200", k1: float = 2.0, b: float = 0.75, + **kwargs, ): try: from elasticsearch import Elasticsearch @@ -31,7 +32,7 @@ class ElasticsearchDocumentStore(BaseDocumentStore): self.b = b # Create an Elasticsearch client instance - self.client = Elasticsearch(elasticsearch_url) + self.client = Elasticsearch(elasticsearch_url, **kwargs) self.es_bulk = bulk # Define the index settings and mappings settings = { @@ -63,19 +64,16 @@ class ElasticsearchDocumentStore(BaseDocumentStore): self, docs: Union[Document, List[Document]], ids: Optional[Union[List[str], str]] = None, - **kwargs + refresh_indices: bool = True, + **kwargs, ): """Add document into document store Args: docs: list of documents to add - ids: specify the ids of documents to add or - use existing doc.doc_id - refresh_indices: request Elasticsearch to update - its index 
(default to True) + ids: specify the ids of documents to add or use existing doc.doc_id + refresh_indices: request Elasticsearch to update its index (default to True) """ - refresh_indices = kwargs.pop("refresh_indices", True) - if ids and not isinstance(ids, list): ids = [ids] if not isinstance(docs, list): @@ -120,7 +118,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore): ) return docs - def query(self, query: str, top_k: int = 10) -> List[Document]: + def query( + self, query: str, top_k: int = 10, doc_ids: Optional[list] = None + ) -> List[Document]: """Search Elasticsearch docstore using search query (BM25) Args: @@ -131,7 +131,9 @@ class ElasticsearchDocumentStore(BaseDocumentStore): Returns: List[Document]: List of result documents """ - query_dict = {"query": {"match": {"content": query}}, "size": top_k} + query_dict: dict = {"query": {"match": {"content": query}}, "size": top_k} + if doc_ids: + query_dict["query"]["match"]["_id"] = {"values": doc_ids} return self.query_raw(query_dict) def get(self, ids: Union[List[str], str]) -> List[Document]: diff --git a/knowledgehub/storages/docstores/in_memory.py b/knowledgehub/storages/docstores/in_memory.py index 3bf22c8..96dbe71 100644 --- a/knowledgehub/storages/docstores/in_memory.py +++ b/knowledgehub/storages/docstores/in_memory.py @@ -74,6 +74,11 @@ class InMemoryDocumentStore(BaseDocumentStore): """Load document store from path""" with open(path) as f: store = json.load(f) + # TODO: save and load aren't lossless. A Document-subclass will lose + # information. Need to edit the `to_dict` and `from_dict` methods in + # the Document class. + # For better query support, utilize SQLite as the default document store. + # Also, for portability, use SQLAlchemy for document store. self._store = {key: Document.from_dict(value) for key, value in store.items()} def __persist_flow__(self): diff --git a/knowledgehub/storages/docstores/simple_file.py b/knowledgehub/storages/docstores/simple_file.py index 8967096..8ee72df 100644 --- a/knowledgehub/storages/docstores/simple_file.py +++ b/knowledgehub/storages/docstores/simple_file.py @@ -15,6 +15,18 @@ class SimpleFileDocumentStore(InMemoryDocumentStore): if path is not None and Path(path).is_file(): self.load(path) + def get(self, ids: Union[List[str], str]) -> List[Document]: + """Get document by id""" + if not isinstance(ids, list): + ids = [ids] + + for doc_id in ids: + if doc_id not in self._store: + self.load(self._path) + break + + return [self._store[doc_id] for doc_id in ids] + def add( self, docs: Union[Document, List[Document]], diff --git a/knowledgehub/storages/vectorstores/base.py b/knowledgehub/storages/vectorstores/base.py index 7f8f2a5..ba4f3ec 100644 --- a/knowledgehub/storages/vectorstores/base.py +++ b/knowledgehub/storages/vectorstores/base.py @@ -76,8 +76,15 @@ class LlamaIndexVectorStore(BaseVectorStore): "Require `_li_class` to set a VectorStore class from LlamarIndex" ) + from dataclasses import fields + self._client = self._li_class(*args, **kwargs) + self._vsq_kwargs = {_.name for _ in fields(VectorStoreQuery)} + for key in ["query_embedding", "similarity_top_k", "node_ids"]: + if key in self._vsq_kwargs: + self._vsq_kwargs.remove(key) + def __setattr__(self, name: str, value: Any) -> None: if name.startswith("_"): return super().__setattr__(name, value) @@ -122,13 +129,35 @@ class LlamaIndexVectorStore(BaseVectorStore): ids: Optional[list[str]] = None, **kwargs, ) -> tuple[list[list[float]], list[float], list[str]]: + """Return the top k most similar vector embeddings 
+ + Args: + embedding: List of embeddings + top_k: Number of most similar embeddings to return + ids: List of ids of the embeddings to be queried + kwargs: extra query parameters. Depending on the name, these parameters + will be used when constructing the VectorStoreQuery object or when + performing querying of the underlying vector store. + + Returns: + the matched embeddings, the similarity scores, and the ids + """ + vsq_kwargs = {} + vs_kwargs = {} + for kwkey, kwvalue in kwargs.items(): + if kwkey in self._vsq_kwargs: + vsq_kwargs[kwkey] = kwvalue + else: + vs_kwargs[kwkey] = kwvalue + output = self._client.query( query=VectorStoreQuery( query_embedding=embedding, similarity_top_k=top_k, node_ids=ids, - **kwargs, + **vsq_kwargs, ), + **vs_kwargs, ) embeddings = [] diff --git a/knowledgehub/storages/vectorstores/chroma.py b/knowledgehub/storages/vectorstores/chroma.py index 431dcdd..641a4d8 100644 --- a/knowledgehub/storages/vectorstores/chroma.py +++ b/knowledgehub/storages/vectorstores/chroma.py @@ -64,7 +64,7 @@ class ChromaVectorStore(LlamaIndexVectorStore): ids: List of ids of the embeddings to be deleted kwargs: meant for vectorstore-specific parameters """ - self._client._collection.delete(ids=ids) + self._client.client.delete(ids=ids) def delete_collection(self, collection_name: Optional[str] = None): """Delete entire collection under specified name from vector stores diff --git a/pyproject.toml b/pyproject.toml index 67a2265..fdc9e33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ requires-python = ">= 3.10" description = "Kotaemon core library for AI development." dependencies = [ "langchain", + "langchain-community", "theflow", "llama-index>=0.9.0", "llama-hub", @@ -56,7 +57,6 @@ dev = [ "python-dotenv", "pytest-mock", "unstructured[pdf]", - # "farm-haystack==1.22.1", "sentence_transformers", "cohere", "elasticsearch", diff --git a/tests/test_agent.py b/tests/test_agent.py index 02e5da9..dad9a33 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -47,6 +47,7 @@ def generate_chat_completion_obj(text): "function_call": None, "tool_calls": None, }, + "logprobs": None, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, diff --git a/tests/test_citation.py b/tests/test_citation.py index b92dd61..4378f59 100644 --- a/tests/test_citation.py +++ b/tests/test_citation.py @@ -30,6 +30,7 @@ _openai_chat_completion_response = [ }, "tool_calls": None, }, + "logprobs": None, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, diff --git a/tests/test_composite.py b/tests/test_composite.py index f9bde83..ce6ef69 100644 --- a/tests/test_composite.py +++ b/tests/test_composite.py @@ -30,6 +30,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj( "finish_reason": "length", "logprobs": None, }, + "logprobs": None, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, diff --git a/tests/test_cot.py b/tests/test_cot.py index 0f4320f..e697485 100644 --- a/tests/test_cot.py +++ b/tests/test_cot.py @@ -23,6 +23,7 @@ _openai_chat_completion_response = [ "function_call": None, "tool_calls": None, }, + "logprobs": None, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, diff --git a/tests/test_ingestor.py b/tests/test_ingestor.py new file mode 100644 index 0000000..cd1450e --- /dev/null +++ b/tests/test_ingestor.py @@ -0,0 +1,15 @@ +from pathlib import Path + +from kotaemon.indices.ingests import DocumentIngestor +from kotaemon.indices.splitters 
import TokenSplitter + + +def test_ingestor_include_src(): + dirpath = Path(__file__).parent + ingestor = DocumentIngestor( + pdf_mode="normal", + text_splitter=TokenSplitter(chunk_size=50, chunk_overlap=10), + ) + nodes = ingestor(dirpath / "resources" / "table.pdf") + assert type(nodes) is list + assert nodes[0].relationships diff --git a/tests/test_llms_chat_models.py b/tests/test_llms_chat_models.py index e7336d6..8e0de5f 100644 --- a/tests/test_llms_chat_models.py +++ b/tests/test_llms_chat_models.py @@ -28,6 +28,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj( "function_call": None, "tool_calls": None, }, + "logprobs": None, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, diff --git a/tests/test_reranking.py b/tests/test_reranking.py index 03f0b72..d4f7be8 100644 --- a/tests/test_reranking.py +++ b/tests/test_reranking.py @@ -25,6 +25,7 @@ _openai_chat_completion_responses = [ "function_call": None, "tool_calls": None, }, + "logprobs": None, } ], "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, diff --git a/tests/test_splitter.py b/tests/test_splitter.py new file mode 100644 index 0000000..71e63ee --- /dev/null +++ b/tests/test_splitter.py @@ -0,0 +1,55 @@ +from llama_index.schema import NodeRelationship + +from kotaemon.base import Document +from kotaemon.indices.splitters import TokenSplitter + +source1 = Document( + content="The City Hall and Raffles Place MRT stations are paired cross-platform " + "interchanges on the North–South line (NSL) and East–West line (EWL) of the " + "Singapore Mass Rapid Transit (MRT) system. Both are situated in the Downtown " + "Core district: City Hall station is near landmarks such as the former City Hall, " + "St Andrew's Cathedral and the Padang, while Raffles Place station serves Merlion " + "Park, The Fullerton Hotel and the Asian Civilisations Museum. The stations were " + "first announced in 1982. Constructing the tunnels between the City Hall and " + "Raffles Place stations required the draining of the Singapore River. The " + "stations opened on 12 December 1987 as part of the MRT extension to Outram Park " + "station. Cross-platform transfers between the NSL and EWL began on 28 October " + "1989, ahead of the split of the MRT network into two lines. Both stations are " + "designated Civil Defence shelters. City Hall station features a mural by Simon" + "Wong which depicts government buildings in the area, while two murals at Raffles " + "Place station by Lim Sew Yong and Thang Kiang How depict scenes of Singapore's " + "history" +) + +source2 = Document( + content="The pink cockatoo (Cacatua leadbeateri) is a medium-sized cockatoo that " + "inhabits arid and semi-arid inland areas across Australia, with the exception of " + "the north east. The bird has a soft-textured white and salmon-pink plumage and " + "large, bright red and yellow crest. The sexes are quite similar, although males " + "are usually bigger while the female has a broader yellow stripe on the crest and " + "develops a red eye when mature. The pink cockatoo is usually found in pairs or " + "small groups, and feeds both on the ground and in trees. It is listed as an " + "endangered species by the Australian government. 
Formerly known as Major " + "Mitchell's cockatoo, after the explorer Thomas Mitchell, the species was " + "officially renamed the pink cockatoo in 2023 by BirdLife Australia in light of " + "Mitchell's involvement in the massacre of Aboriginal people at Mount Dispersion, " + "as well as a general trend to make Australian species names more culturally " + "inclusive. This pink cockatoo with a raised crest was photographed near Mount " + "Grenfell in New South Wales." +) + + +def test_split_token(): + """Test that it can split tokens successfully""" + splitter = TokenSplitter(chunk_size=30, chunk_overlap=10) + chunks = splitter([source1, source2]) + + assert isinstance(chunks, list), "Chunks should be a list" + assert isinstance(chunks[0], Document), "Chunks should be a list of Documents" + + assert chunks[0].relationships[NodeRelationship.SOURCE].node_id == source1.doc_id + assert ( + chunks[1].relationships[NodeRelationship.PREVIOUS].node_id == chunks[0].doc_id + ) + assert chunks[1].relationships[NodeRelationship.NEXT].node_id == chunks[2].doc_id + assert chunks[-1].relationships[NodeRelationship.SOURCE].node_id == source2.doc_id
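
Usage note (an illustrative sketch, not part of the diff): with this patch, the
Langchain-based chat models gain a `stream` flag; `run()` dispatches to
`invoke()` for a single response, or to `stream()`, which yields `LLMInterface`
chunks. The endpoint, key, API version, and deployment name below are
placeholders:

    from kotaemon.llms import AzureChatOpenAI

    llm = AzureChatOpenAI(
        stream=True,  # new flag from this patch; run() now delegates to stream()
        azure_endpoint="https://<your-resource>.openai.azure.com/",  # placeholder
        openai_api_key="<api-key>",  # placeholder
        openai_api_version="2023-07-01-preview",  # placeholder
        deployment_name="<deployment>",  # placeholder
    )

    # Each iteration yields an LLMInterface wrapping one streamed chunk.
    for chunk in llm("Why is the sky blue?"):
        print(chunk.content, end="", flush=True)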
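A similar sketch for the extended ingestion path: .doc/.docx and .xls files are
now routed to UnstructuredReader. This mirrors tests/test_ingestor.py;
"report.docx" is a placeholder file, and the unstructured extras must be
installed:

    from pathlib import Path

    from kotaemon.indices.ingests import DocumentIngestor
    from kotaemon.indices.splitters import TokenSplitter

    ingestor = DocumentIngestor(
        pdf_mode="normal",  # or "mathpix" / "ocr" for image-heavy PDFs
        text_splitter=TokenSplitter(chunk_size=50, chunk_overlap=10),
    )
    nodes = ingestor(Path("report.docx"))  # handled by UnstructuredReader
    print(len(nodes), nodes[0].metadata.get("file_name"))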