Update Base interface of Index/Retrieval pipeline (#36)
* add base Tool
* minor update test_tool
* update test dependency
* update test dependency
* Fix namespace conflict
* update test
* add base Agent Interface, add ReWoo Agent
* minor update
* update test
* fix typo
* remove unneeded print
* update rewoo agent
* add LLMTool
* update BaseAgent type
* add ReAct agent
* add ReAct agent
* minor update
* minor update
* minor update
* minor update
* update base reader with BaseComponent
* add splitter
* update agent and tool
* update vectorstores
* update load/save for indexing and retrieving pipeline
* update test_agent for more use-cases
* add missing dependency for test
* update test case for in memory vectorstore
* add TextSplitter to BaseComponent
* update type hint basetool

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
49ed3f6994
commit
56bc41b673
|
@ -1,6 +1,12 @@
|
|||
from .base import AutoReader
|
||||
from .base import AutoReader, DirectoryReader
|
||||
from .excel_loader import PandasExcelReader
|
||||
from .mathpix_loader import MathpixPDFReader
|
||||
from .ocr_loader import OCRReader
|
||||
|
||||
__all__ = ["AutoReader", "PandasExcelReader", "MathpixPDFReader", "OCRReader"]
|
||||
__all__ = [
|
||||
"AutoReader",
|
||||
"PandasExcelReader",
|
||||
"MathpixPDFReader",
|
||||
"OCRReader",
|
||||
"DirectoryReader",
|
||||
]
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from pathlib import Path
|
||||
from typing import Any, List, Type, Union
|
||||
|
||||
from llama_index import download_loader
|
||||
from llama_index import SimpleDirectoryReader, download_loader
|
||||
from llama_index.readers.base import BaseReader
|
||||
|
||||
from ..base import BaseComponent
|
||||
from ..documents.base import Document
|
||||
|
||||
|
||||
class AutoReader(BaseReader):
|
||||
class AutoReader(BaseComponent, BaseReader):
|
||||
"""General auto reader for a variety of files. (based on llama-hub)"""
|
||||
|
||||
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
|
||||
|
@ -17,6 +18,7 @@ class AutoReader(BaseReader):
|
|||
self._reader = download_loader(reader_type)()
|
||||
else:
|
||||
self._reader = reader_type()
|
||||
super().__init__()
|
||||
|
||||
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
|
||||
documents = self._reader.load_data(file=file, **kwargs)
|
||||
|
@ -24,3 +26,42 @@ class AutoReader(BaseReader):
|
|||
# convert Document to new base class from kotaemon
|
||||
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
||||
return converted_documents
|
||||
|
||||
def run(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
    """Run the reader on a file; a thin alias of `load_data`.

    Args:
        file: path of the file to read.
        **kwargs: extra options passed through unchanged to `load_data`.

    Returns:
        The documents returned by `load_data` for the same arguments.
    """
    loaded_documents = self.load_data(file=file, **kwargs)
    return loaded_documents
|
||||
|
||||
|
||||
class LIBaseReader(BaseComponent, BaseReader):
    """Wrap a LlamaIndex reader class so it can be used as a component.

    Subclasses set `_reader_class`. An instance of that class is created in
    `__init__` and kept in `self._reader`; attribute reads that are not found
    on the wrapper and writes to public attributes are forwarded to it.
    """

    _reader_class: Type[BaseReader]

    def __init__(self, *args, **kwargs):
        """Instantiate the wrapped reader with the given arguments.

        Args:
            *args: positional arguments passed to `_reader_class`.
            **kwargs: keyword arguments passed to `_reader_class`.

        Raises:
            AttributeError: if the class does not define `_reader_class`.
        """
        # Look the attribute up on the type with a default. `_reader_class` is
        # only annotated here, so a plain `self._reader_class` on a subclass
        # that forgot to set it would fall into `__getattr__` before
        # `_reader` exists and never reach this check.
        reader_class = getattr(type(self), "_reader_class", None)
        if reader_class is None:
            raise AttributeError(
                "Require `_reader_class` to set a BaseReader class from LlamaIndex"
            )

        self._reader = reader_class(*args, **kwargs)
        super().__init__()

    def __setattr__(self, name: str, value: Any) -> None:
        """Store private names on the wrapper; forward the rest to the reader."""
        if name.startswith("_"):
            return super().__setattr__(name, value)

        return setattr(self._reader, name, value)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped reader.

        Raises:
            AttributeError: if `_reader` itself is requested but not set yet.
        """
        # A missing `_reader` means the wrapper is not initialised; fail here
        # instead of re-entering this method through `self._reader`.
        if name == "_reader":
            raise AttributeError(name)
        return getattr(self._reader, name)

    def load_data(self, *args, **kwargs: Any) -> List[Document]:
        """Load documents with the wrapped reader and convert them.

        Args:
            *args: positional arguments passed to the wrapped `load_data`.
            **kwargs: keyword arguments passed to the wrapped `load_data`.

        Returns:
            The loaded documents, each rebuilt as a kotaemon `Document`.
        """
        documents = self._reader.load_data(*args, **kwargs)

        # convert Document to new base class from kotaemon
        return [Document.from_dict(doc.to_dict()) for doc in documents]

    def run(self, *args, **kwargs: Any) -> List[Document]:
        """Run the reader; a thin alias of `load_data`."""
        return self.load_data(*args, **kwargs)
|
||||
|
||||
class DirectoryReader(LIBaseReader):
    """`LIBaseReader` whose wrapped reader class is `SimpleDirectoryReader`."""

    _reader_class = SimpleDirectoryReader
|
||||
|
|
0
knowledgehub/parsers/__init__.py
Normal file
0
knowledgehub/parsers/__init__.py
Normal file
65
knowledgehub/parsers/splitter.py
Normal file
65
knowledgehub/parsers/splitter.py
Normal file
|
@ -0,0 +1,65 @@
|
|||
from typing import Any, List, Sequence, Type
|
||||
|
||||
from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
|
||||
from llama_index.node_parser.interface import NodeParser
|
||||
from llama_index.text_splitter import TokenTextSplitter
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
|
||||
from ..documents.base import Document
|
||||
|
||||
# Public API of this module: the re-exported LlamaIndex splitter plus the
# parser wrappers defined below.
__all__ = ["TokenTextSplitter", "LINodeParser", "SimpleNodeParser"]
|
||||
|
||||
|
||||
class LINodeParser(BaseComponent):
    """Wrap a LlamaIndex node parser class so it can be used as a component.

    Subclasses set `_parser_class`. An instance of that class is created in
    `__init__` and kept in `self._parser`; attribute reads that are not found
    on the wrapper are forwarded to it.
    """

    _parser_class: Type[NodeParser]

    def __init__(self, *args, **kwargs):
        """Instantiate the wrapped node parser with the given arguments.

        Args:
            *args: positional arguments passed to `_parser_class`.
            **kwargs: keyword arguments passed to `_parser_class`.

        Raises:
            AttributeError: if the class does not define `_parser_class`.
        """
        # Look the attribute up on the type with a default. `_parser_class` is
        # only annotated here, so a plain `self._parser_class` on a subclass
        # that forgot to set it would fall into `__getattr__` before
        # `_parser` exists and never reach this check.
        parser_class = getattr(type(self), "_parser_class", None)
        if parser_class is None:
            raise AttributeError(
                "Require `_parser_class` to set a NodeParser class from LlamaIndex"
            )
        self._parser = parser_class(*args, **kwargs)
        super().__init__()

    def __setattr__(self, name: str, value: Any) -> None:
        """Store private/protected names on the wrapper; forward the rest."""
        if name.startswith("_") or name in self._protected_keywords():
            return super().__setattr__(name, value)

        return setattr(self._parser, name, value)

    def __getattr__(self, name: str) -> Any:
        """Delegate attributes not found on the wrapper to the wrapped parser.

        Raises:
            AttributeError: if `_parser` itself is requested but not set yet.
        """
        # A missing `_parser` means the wrapper is not initialised; fail here
        # instead of re-entering this method through `self._parser`.
        if name == "_parser":
            raise AttributeError(name)
        return getattr(self._parser, name)

    def get_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[Document]:
        """Split documents with the wrapped parser and convert the result.

        Args:
            documents: documents handed to the wrapped parser.
            show_progress: passed through to the wrapped parser.

        Returns:
            The parser's output, each item rebuilt as a kotaemon `Document`.
        """
        nodes = self._parser.get_nodes_from_documents(
            documents=documents, show_progress=show_progress
        )
        # convert Document to new base class from kotaemon
        return [Document.from_dict(node.to_dict()) for node in nodes]

    def run(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[Document]:
        """Run the parser; a thin alias of `get_nodes_from_documents`."""
        return self.get_nodes_from_documents(
            documents=documents, show_progress=show_progress
        )
|
||||
|
||||
|
||||
class SimpleNodeParser(LINodeParser):
    """`LINodeParser` around LlamaIndex's `SimpleNodeParser`.

    The `chunk_size` (default 512) and `chunk_overlap` (default 0) keyword
    arguments are removed from `kwargs` and used to build a
    `TokenTextSplitter`, which is passed on as `text_splitter`.
    """

    _parser_class = LISimpleNodeParser

    def __init__(self, *args, **kwargs):
        """Build the token splitter, then initialise the wrapped parser."""
        splitter_options = {
            "chunk_size": kwargs.pop("chunk_size", 512),
            "chunk_overlap": kwargs.pop("chunk_overlap", 0),
        }
        # Any `text_splitter` supplied by the caller is replaced here.
        kwargs["text_splitter"] = TokenTextSplitter(**splitter_options)
        super().__init__(*args, **kwargs)
|
|
@ -1,3 +1,5 @@
|
|||
from .base import BaseAgent
|
||||
from .react.agent import ReactAgent
|
||||
from .rewoo.agent import RewooAgent
|
||||
|
||||
__all__ = ["BaseAgent"]
|
||||
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent"]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import uuid
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
|
@ -9,6 +10,9 @@ from ..documents.base import Document
|
|||
from ..embeddings import BaseEmbeddings
|
||||
from ..vectorstores import BaseVectorStore
|
||||
|
||||
VECTOR_STORE_FNAME = "vectorstore"
|
||||
DOC_STORE_FNAME = "docstore"
|
||||
|
||||
|
||||
class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
||||
"""Ingest the document, run through the embedding, and store the embedding in a
|
||||
|
@ -20,7 +24,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
|||
"""
|
||||
|
||||
vector_store: Param[BaseVectorStore] = Param()
|
||||
doc_store: Optional[BaseDocumentStore] = None
|
||||
doc_store: Param[BaseDocumentStore] = Param()
|
||||
embedding: Node[BaseEmbeddings] = Node()
|
||||
|
||||
# TODO: refer to llama_index's storage as well
|
||||
|
@ -30,7 +34,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
|||
self.run_batch_document([document])
|
||||
|
||||
def run_batch_raw(self, text: List[str]) -> None:
|
||||
documents = [Document(t, id_=str(uuid.uuid4())) for t in text]
|
||||
documents = [Document(text=t, id_=str(uuid.uuid4())) for t in text]
|
||||
self.run_batch_document(documents)
|
||||
|
||||
def run_document(self, text: Document) -> None:
|
||||
|
@ -57,13 +61,31 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
|||
return True
|
||||
return False
|
||||
|
||||
def persist(self, path: str):
|
||||
def save(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Save the indexing pipeline's vector store and document store to disk.

    Args:
        path: base path; each store is saved to `path / <store file name>`.
        vectorstore_fname: name joined onto `path` for the vector store.
        docstore_fname: name joined onto `path` for the document store.
    """
    # Only plain strings are converted; anything else is used as given.
    base_path = Path(path) if isinstance(path, str) else path
    self.vector_store.save(base_path / vectorstore_fname)
    self.doc_store.save(base_path / docstore_fname)
|
||||
|
||||
def load(self, path: str):
|
||||
def load(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Load the vector store and document store of this pipeline from disk.

    Args:
        path: base path; each store is loaded from `path / <store file name>`.
        vectorstore_fname: name joined onto `path` for the vector store.
        docstore_fname: name joined onto `path` for the document store.
    """
    # Only plain strings are converted; anything else is used as given.
    base_path = Path(path) if isinstance(path, str) else path
    self.vector_store.load(base_path / vectorstore_fname)
    self.doc_store.load(base_path / docstore_fname)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
from abc import abstractmethod
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
|
||||
from theflow import Node, Param
|
||||
|
||||
|
@ -9,6 +10,9 @@ from ..documents.base import Document, RetrievedDocument
|
|||
from ..embeddings import BaseEmbeddings
|
||||
from ..vectorstores import BaseVectorStore
|
||||
|
||||
VECTOR_STORE_FNAME = "vectorstore"
|
||||
DOC_STORE_FNAME = "docstore"
|
||||
|
||||
|
||||
class BaseRetrieval(BaseComponent):
|
||||
"""Define the base interface of a retrieval pipeline"""
|
||||
|
@ -38,7 +42,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
|
|||
"""Retrieve list of documents from vector store"""
|
||||
|
||||
vector_store: Param[BaseVectorStore] = Param()
|
||||
doc_store: Optional[BaseDocumentStore] = None
|
||||
doc_store: Param[BaseDocumentStore] = Param()
|
||||
embedding: Node[BaseEmbeddings] = Node()
|
||||
# TODO: refer to llama_index's storage as well
|
||||
|
||||
|
@ -86,13 +90,31 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
|
|||
return True
|
||||
return False
|
||||
|
||||
def persist(self, path: str):
|
||||
def save(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Save the retrieval pipeline's vector store and document store to disk.

    Args:
        path: base path; each store is saved to `path / <store file name>`.
        vectorstore_fname: name joined onto `path` for the vector store.
        docstore_fname: name joined onto `path` for the document store.
    """
    # Only plain strings are converted; anything else is used as given.
    base_path = Path(path) if isinstance(path, str) else path
    self.vector_store.save(base_path / vectorstore_fname)
    self.doc_store.save(base_path / docstore_fname)
|
||||
|
||||
def load(self, path: str):
|
||||
def load(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Load the vector store and document store of this pipeline from disk.

    Args:
        path: base path; each store is loaded from `path / <store file name>`.
        vectorstore_fname: name joined onto `path` for the vector store.
        docstore_fname: name joined onto `path` for the document store.
    """
    # Only plain strings are converted; anything else is used as given.
    base_path = Path(path) if isinstance(path, str) else path
    self.vector_store.load(base_path / vectorstore_fname)
    self.doc_store.load(base_path / docstore_fname)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from langchain.agents import Tool as LCTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
|
@ -87,6 +88,10 @@ class BaseTool(BaseComponent):
|
|||
)
|
||||
return observation
|
||||
|
||||
def to_langchain_format(self) -> LCTool:
    """Build a Langchain `Tool` from this tool.

    Returns:
        A Langchain `Tool` carrying this tool's `name` and `description`,
        with this tool's `run` method as its `func`.
    """
    langchain_tool = LCTool(
        name=self.name,
        description=self.description,
        func=self.run,
    )
    return langchain_tool
|
||||
|
||||
def run_raw(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
|
@ -122,6 +127,15 @@ class BaseTool(BaseComponent):
|
|||
"""Tool does not support processing batch"""
|
||||
return False
|
||||
|
||||
@classmethod
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
    """Wrap a Langchain tool in a `BaseTool`.

    Args:
        langchain_tool: the Langchain tool to wrap.

    Returns:
        A `BaseTool` with the Langchain tool's `name` and `description`,
        whose `_run_tool` is the Langchain tool's `_run`.
    """
    # NOTE(review): this always builds a plain `BaseTool`, even when called
    # on a subclass (`cls` is unused) - confirm that is intended.
    wrapped_tool = BaseTool(
        name=langchain_tool.name,
        description=langchain_tool.description,
    )
    wrapped_tool._run_tool = langchain_tool._run  # type: ignore
    return wrapped_tool
|
||||
|
||||
|
||||
class ComponentTool(BaseTool):
|
||||
"""
|
||||
|
@ -130,6 +144,11 @@ class ComponentTool(BaseTool):
|
|||
"""
|
||||
|
||||
component: BaseComponent
|
||||
postprocessor: Optional[Callable] = None
|
||||
|
||||
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
    """Call the wrapped component and optionally post-process its output.

    Args:
        *args: positional arguments passed to `self.component`.
        **kwargs: keyword arguments passed to `self.component`.

    Returns:
        The component's output, passed through `self.postprocessor` when one
        is set (truthy); otherwise the raw output.
    """
    result = self.component(*args, **kwargs)
    if not self.postprocessor:
        return result

    return self.postprocessor(result)
|
||||
|
|
|
@ -67,6 +67,9 @@ class ChromaVectorStore(LlamaIndexVectorStore):
|
|||
collection_name = self._client.client.name
|
||||
self._client.client._client.delete_collection(collection_name)
|
||||
|
||||
def count(self) -> int:
    """Return the count reported by the underlying `_collection`."""
    total = self._collection.count()
    return total
|
||||
|
||||
def save(self, *args, **kwargs):
    """Accept and ignore any arguments; this method does nothing.

    NOTE(review): presumably a no-op because the Chroma client handles its
    own persistence - confirm.
    """
    return None
|
||||
|
||||
|
|
|
@ -44,9 +44,7 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
|
|||
"""
|
||||
self._client.persist(persist_path=save_path, fs=fs)
|
||||
|
||||
def load(
|
||||
self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
|
||||
) -> "InMemoryVectorStore":
|
||||
def load(self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None):
|
||||
|
||||
"""Create a SimpleKVStore from a load directory.
|
||||
|
||||
|
@ -54,4 +52,4 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
|
|||
load_path: Path of loading vector.
|
||||
fs: An abstract super-class for pythonic file-systems
|
||||
"""
|
||||
return self._client.from_persist_path(persist_path=load_path, fs=fs)
|
||||
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)
|
||||
|
|
1
setup.py
1
setup.py
|
@ -53,6 +53,7 @@ setuptools.setup(
|
|||
"openai",
|
||||
"chromadb",
|
||||
"wikipedia",
|
||||
"duckduckgo-search",
|
||||
"googlesearch-python",
|
||||
"python-dotenv",
|
||||
"pytest-mock",
|
||||
|
|
|
@ -1,11 +1,19 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||
from kotaemon.pipelines.agents.react import ReactAgent
|
||||
from kotaemon.pipelines.agents.rewoo import RewooAgent
|
||||
from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool
|
||||
from kotaemon.pipelines.tools import (
|
||||
BaseTool,
|
||||
GoogleSearchTool,
|
||||
LLMTool,
|
||||
WikipediaTool,
|
||||
)
|
||||
|
||||
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
|
||||
|
||||
_openai_chat_completion_responses_rewoo = [
|
||||
{
|
||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||
|
@ -73,19 +81,61 @@ _openai_chat_completion_responses_react = [
|
|||
]
|
||||
]
|
||||
|
||||
_openai_chat_completion_responses_react_langchain_tool = [
|
||||
{
|
||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||
"object": "chat.completion",
|
||||
"created": 1692338378,
|
||||
"model": "gpt-35-turbo",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": text,
|
||||
},
|
||||
}
|
||||
],
|
||||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||
}
|
||||
for text in [
|
||||
(
|
||||
"I don't have prior knowledge about Cinnamon AI company, "
|
||||
"so I should gather information about it.\n"
|
||||
"Action: Wikipedia\n"
|
||||
"Action Input: Cinnamon AI company\n"
|
||||
),
|
||||
(
|
||||
"The information retrieved from Wikipedia is not "
|
||||
"about Cinnamon AI company, but about Blue Prism, "
|
||||
"a British multinational software corporation. "
|
||||
"I need to try another source to gather information "
|
||||
"about Cinnamon AI company.\n"
|
||||
"Action: duckduckgo_search\n"
|
||||
"Action Input: Cinnamon AI company\n"
|
||||
),
|
||||
FINAL_RESPONSE_TEXT,
|
||||
]
|
||||
]
|
||||
|
||||
@patch(
|
||||
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||
side_effect=_openai_chat_completion_responses_rewoo,
|
||||
)
|
||||
def test_rewoo_agent(openai_completion):
|
||||
llm = AzureChatOpenAI(
|
||||
|
||||
@pytest.fixture
def llm():
    """Provide an `AzureChatOpenAI` instance built with dummy settings."""
    dummy_settings = dict(
        openai_api_base="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
        deployment_name="dummy-q2",
        temperature=0,
    )
    return AzureChatOpenAI(**dummy_settings)
|
||||
|
||||
|
||||
@patch(
|
||||
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||
side_effect=_openai_chat_completion_responses_rewoo,
|
||||
)
|
||||
def test_rewoo_agent(openai_completion, llm):
|
||||
plugins = [
|
||||
GoogleSearchTool(),
|
||||
WikipediaTool(),
|
||||
|
@ -103,14 +153,7 @@ def test_rewoo_agent(openai_completion):
|
|||
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||
side_effect=_openai_chat_completion_responses_react,
|
||||
)
|
||||
def test_react_agent(openai_completion):
|
||||
llm = AzureChatOpenAI(
|
||||
openai_api_base="https://dummy.openai.azure.com/",
|
||||
openai_api_key="dummy",
|
||||
openai_api_version="2023-03-15-preview",
|
||||
deployment_name="dummy-q2",
|
||||
temperature=0,
|
||||
)
|
||||
def test_react_agent(openai_completion, llm):
|
||||
plugins = [
|
||||
GoogleSearchTool(),
|
||||
WikipediaTool(),
|
||||
|
@ -121,3 +164,47 @@ def test_react_agent(openai_completion):
|
|||
response = agent("Tell me about Cinnamon AI company")
|
||||
openai_completion.assert_called()
|
||||
assert response.output == FINAL_RESPONSE_TEXT
|
||||
|
||||
|
||||
@patch(
    "openai.api_resources.chat_completion.ChatCompletion.create",
    side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent_langchain(openai_completion, llm):
    """Convert kotaemon tools to Langchain format and run a Langchain agent."""
    from langchain.agents import AgentType, initialize_agent

    kotaemon_tools = [
        GoogleSearchTool(),
        WikipediaTool(),
        LLMTool(llm=llm),
    ]
    langchain_tools = [each.to_langchain_format() for each in kotaemon_tools]

    langchain_agent = initialize_agent(
        langchain_tools,
        llm.agent,
        agent=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
    )
    result = langchain_agent("Tell me about Cinnamon AI company")

    # The patched completion endpoint must have been hit at least once.
    openai_completion.assert_called()
    assert result
|
||||
|
||||
|
||||
@patch(
    "openai.api_resources.chat_completion.ChatCompletion.create",
    side_effect=_openai_chat_completion_responses_react_langchain_tool,
)
def test_react_agent_with_langchain_tools(openai_completion, llm):
    """Wrap Langchain tools as `BaseTool`s and run them through `ReactAgent`."""
    from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
    from langchain.utilities import WikipediaAPIWrapper

    langchain_tools = [
        WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper()),
        DuckDuckGoSearchRun(),
    ]
    wrapped_tools = [BaseTool.from_langchain_format(each) for each in langchain_tools]

    react_agent = ReactAgent(llm=llm, plugins=wrapped_tools, max_iterations=4)
    result = react_agent("Tell me about Cinnamon AI company")

    # The patched completion endpoint must have been hit at least once.
    openai_completion.assert_called()
    assert result.output == FINAL_RESPONSE_TEXT
|
||||
|
|
|
@ -116,8 +116,8 @@ class TestInMemoryVectorStore:
|
|||
"3" not in data["text_id_to_ref_doc_id"]
|
||||
), "delete function does not delete data completely"
|
||||
db2 = InMemoryVectorStore()
|
||||
output = db2.load(load_path=tmp_path / "test_save_load_delete.json")
|
||||
assert output.get("2") == [
|
||||
db2.load(load_path=tmp_path / "test_save_load_delete.json")
|
||||
assert db2.get("2") == [
|
||||
0.4,
|
||||
0.5,
|
||||
0.6,
|
||||
|
|
Loading…
Reference in New Issue
Block a user