Update Base interface of Index/Retrieval pipeline (#36)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

* add LLMTool

* update BaseAgent type

* add ReAct agent

* add ReAct agent

* minor update

* minor update

* minor update

* minor update

* update base reader with BaseComponent

* add splitter

* update agent and tool

* update vectorstores

* update load/save for indexing and retrieving pipeline

* update test_agent for more use-cases

* add missing dependency for test

* update test case for in memory vectorstore

* add TextSplitter to BaseComponent

* update type hint basetool

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2023-10-04 14:27:44 +07:00 committed by GitHub
parent 49ed3f6994
commit 56bc41b673
13 changed files with 302 additions and 36 deletions

View File

@ -1,6 +1,12 @@
from .base import AutoReader from .base import AutoReader, DirectoryReader
from .excel_loader import PandasExcelReader from .excel_loader import PandasExcelReader
from .mathpix_loader import MathpixPDFReader from .mathpix_loader import MathpixPDFReader
from .ocr_loader import OCRReader from .ocr_loader import OCRReader
__all__ = ["AutoReader", "PandasExcelReader", "MathpixPDFReader", "OCRReader"] __all__ = [
"AutoReader",
"PandasExcelReader",
"MathpixPDFReader",
"OCRReader",
"DirectoryReader",
]

View File

@ -1,13 +1,14 @@
from pathlib import Path from pathlib import Path
from typing import Any, List, Type, Union from typing import Any, List, Type, Union
from llama_index import download_loader from llama_index import SimpleDirectoryReader, download_loader
from llama_index.readers.base import BaseReader from llama_index.readers.base import BaseReader
from ..base import BaseComponent
from ..documents.base import Document from ..documents.base import Document
class AutoReader(BaseReader): class AutoReader(BaseComponent, BaseReader):
"""General auto reader for a variety of files. (based on llama-hub)""" """General auto reader for a variety of files. (based on llama-hub)"""
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None: def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
@ -17,6 +18,7 @@ class AutoReader(BaseReader):
self._reader = download_loader(reader_type)() self._reader = download_loader(reader_type)()
else: else:
self._reader = reader_type() self._reader = reader_type()
super().__init__()
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]: def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
documents = self._reader.load_data(file=file, **kwargs) documents = self._reader.load_data(file=file, **kwargs)
@ -24,3 +26,42 @@ class AutoReader(BaseReader):
# convert Document to new base class from kotaemon # convert Document to new base class from kotaemon
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents] converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
return converted_documents return converted_documents
def run(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
return self.load_data(file=file, **kwargs)
class LIBaseReader(BaseComponent, BaseReader):
_reader_class: Type[BaseReader]
def __init__(self, *args, **kwargs):
if self._reader_class is None:
raise AttributeError(
"Require `_reader_class` to set a BaseReader class from LlamarIndex"
)
self._reader = self._reader_class(*args, **kwargs)
super().__init__()
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_"):
return super().__setattr__(name, value)
return setattr(self._reader, name, value)
def __getattr__(self, name: str) -> Any:
return getattr(self._reader, name)
def load_data(self, *args, **kwargs: Any) -> List[Document]:
documents = self._reader.load_data(*args, **kwargs)
# convert Document to new base class from kotaemon
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
return converted_documents
def run(self, *args, **kwargs: Any) -> List[Document]:
return self.load_data(*args, **kwargs)
class DirectoryReader(LIBaseReader):
_reader_class = SimpleDirectoryReader

View File

View File

@ -0,0 +1,65 @@
from typing import Any, List, Sequence, Type
from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
from llama_index.node_parser.interface import NodeParser
from llama_index.text_splitter import TokenTextSplitter
from kotaemon.base import BaseComponent
from ..documents.base import Document
__all__ = ["TokenTextSplitter"]
class LINodeParser(BaseComponent):
_parser_class: Type[NodeParser]
def __init__(self, *args, **kwargs):
if self._parser_class is None:
raise AttributeError(
"Require `_parser_class` to set a NodeParser class from LlamarIndex"
)
self._parser = self._parser_class(*args, **kwargs)
super().__init__()
def __setattr__(self, name: str, value: Any) -> None:
if name.startswith("_") or name in self._protected_keywords():
return super().__setattr__(name, value)
return setattr(self._parser, name, value)
def __getattr__(self, name: str) -> Any:
return getattr(self._parser, name)
def get_nodes_from_documents(
self,
documents: Sequence[Document],
show_progress: bool = False,
) -> List[Document]:
documents = self._parser.get_nodes_from_documents(
documents=documents, show_progress=show_progress
)
# convert Document to new base class from kotaemon
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
return converted_documents
def run(
self,
documents: Sequence[Document],
show_progress: bool = False,
) -> List[Document]:
return self.get_nodes_from_documents(
documents=documents, show_progress=show_progress
)
class SimpleNodeParser(LINodeParser):
_parser_class = LISimpleNodeParser
def __init__(self, *args, **kwargs):
chunk_size = kwargs.pop("chunk_size", 512)
chunk_overlap = kwargs.pop("chunk_overlap", 0)
kwargs["text_splitter"] = TokenTextSplitter(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
)
super().__init__(*args, **kwargs)

View File

@ -1,3 +1,5 @@
from .base import BaseAgent from .base import BaseAgent
from .react.agent import ReactAgent
from .rewoo.agent import RewooAgent
__all__ = ["BaseAgent"] __all__ = ["BaseAgent", "ReactAgent", "RewooAgent"]

View File

@ -1,5 +1,6 @@
import uuid import uuid
from typing import List, Optional from pathlib import Path
from typing import List, Union
from theflow import Node, Param from theflow import Node, Param
@ -9,6 +10,9 @@ from ..documents.base import Document
from ..embeddings import BaseEmbeddings from ..embeddings import BaseEmbeddings
from ..vectorstores import BaseVectorStore from ..vectorstores import BaseVectorStore
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
class IndexVectorStoreFromDocumentPipeline(BaseComponent): class IndexVectorStoreFromDocumentPipeline(BaseComponent):
"""Ingest the document, run through the embedding, and store the embedding in a """Ingest the document, run through the embedding, and store the embedding in a
@ -20,7 +24,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
""" """
vector_store: Param[BaseVectorStore] = Param() vector_store: Param[BaseVectorStore] = Param()
doc_store: Optional[BaseDocumentStore] = None doc_store: Param[BaseDocumentStore] = Param()
embedding: Node[BaseEmbeddings] = Node() embedding: Node[BaseEmbeddings] = Node()
# TODO: refer to llama_index's storage as well # TODO: refer to llama_index's storage as well
@ -30,7 +34,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
self.run_batch_document([document]) self.run_batch_document([document])
def run_batch_raw(self, text: List[str]) -> None: def run_batch_raw(self, text: List[str]) -> None:
documents = [Document(t, id_=str(uuid.uuid4())) for t in text] documents = [Document(text=t, id_=str(uuid.uuid4())) for t in text]
self.run_batch_document(documents) self.run_batch_document(documents)
def run_document(self, text: Document) -> None: def run_document(self, text: Document) -> None:
@ -57,13 +61,31 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
return True return True
return False return False
def persist(self, path: str): def save(
self,
path: Union[str, Path],
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Save the whole state of the indexing pipeline vector store and all """Save the whole state of the indexing pipeline vector store and all
necessary information to disk necessary information to disk
Args: Args:
path (str): path to save the state path (str): path to save the state
""" """
if isinstance(path, str):
path = Path(path)
self.vector_store.save(path / vectorstore_fname)
self.doc_store.save(path / docstore_fname)
def load(self, path: str): def load(
self,
path: Union[str, Path],
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Load all information from disk to an object""" """Load all information from disk to an object"""
if isinstance(path, str):
path = Path(path)
self.vector_store.load(path / vectorstore_fname)
self.doc_store.load(path / docstore_fname)

View File

@ -1,5 +1,6 @@
from abc import abstractmethod from abc import abstractmethod
from typing import List, Optional from pathlib import Path
from typing import List, Union
from theflow import Node, Param from theflow import Node, Param
@ -9,6 +10,9 @@ from ..documents.base import Document, RetrievedDocument
from ..embeddings import BaseEmbeddings from ..embeddings import BaseEmbeddings
from ..vectorstores import BaseVectorStore from ..vectorstores import BaseVectorStore
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
class BaseRetrieval(BaseComponent): class BaseRetrieval(BaseComponent):
"""Define the base interface of a retrieval pipeline""" """Define the base interface of a retrieval pipeline"""
@ -38,7 +42,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
"""Retrieve list of documents from vector store""" """Retrieve list of documents from vector store"""
vector_store: Param[BaseVectorStore] = Param() vector_store: Param[BaseVectorStore] = Param()
doc_store: Optional[BaseDocumentStore] = None doc_store: Param[BaseDocumentStore] = Param()
embedding: Node[BaseEmbeddings] = Node() embedding: Node[BaseEmbeddings] = Node()
# TODO: refer to llama_index's storage as well # TODO: refer to llama_index's storage as well
@ -86,13 +90,31 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
return True return True
return False return False
def persist(self, path: str): def save(
self,
path: Union[str, Path],
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Save the whole state of the indexing pipeline vector store and all """Save the whole state of the indexing pipeline vector store and all
necessary information to disk necessary information to disk
Args: Args:
path (str): path to save the state path (str): path to save the state
""" """
if isinstance(path, str):
path = Path(path)
self.vector_store.save(path / vectorstore_fname)
self.doc_store.save(path / docstore_fname)
def load(self, path: str): def load(
self,
path: Union[str, Path],
vectorstore_fname: str = VECTOR_STORE_FNAME,
docstore_fname: str = DOC_STORE_FNAME,
):
"""Load all information from disk to an object""" """Load all information from disk to an object"""
if isinstance(path, str):
path = Path(path)
self.vector_store.load(path / vectorstore_fname)
self.doc_store.load(path / docstore_fname)

View File

@ -1,6 +1,7 @@
from abc import abstractmethod from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from langchain.agents import Tool as LCTool
from pydantic import BaseModel from pydantic import BaseModel
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
@ -87,6 +88,10 @@ class BaseTool(BaseComponent):
) )
return observation return observation
def to_langchain_format(self) -> LCTool:
"""Convert this tool to Langchain format to use with its agent"""
return LCTool(name=self.name, description=self.description, func=self.run)
def run_raw( def run_raw(
self, self,
tool_input: Union[str, Dict], tool_input: Union[str, Dict],
@ -122,6 +127,15 @@ class BaseTool(BaseComponent):
"""Tool does not support processing batch""" """Tool does not support processing batch"""
return False return False
@classmethod
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
"""Wrapper for Langchain Tool"""
new_tool = BaseTool(
name=langchain_tool.name, description=langchain_tool.description
)
new_tool._run_tool = langchain_tool._run # type: ignore
return new_tool
class ComponentTool(BaseTool): class ComponentTool(BaseTool):
""" """
@ -130,6 +144,11 @@ class ComponentTool(BaseTool):
""" """
component: BaseComponent component: BaseComponent
postprocessor: Optional[Callable] = None
def _run_tool(self, *args: Any, **kwargs: Any) -> Any: def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
return self.component(*args, **kwargs) output = self.component(*args, **kwargs)
if self.postprocessor:
output = self.postprocessor(output)
return output

View File

@ -67,6 +67,9 @@ class ChromaVectorStore(LlamaIndexVectorStore):
collection_name = self._client.client.name collection_name = self._client.client.name
self._client.client._client.delete_collection(collection_name) self._client.client._client.delete_collection(collection_name)
def count(self) -> int:
return self._collection.count()
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
pass pass

View File

@ -44,9 +44,7 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
""" """
self._client.persist(persist_path=save_path, fs=fs) self._client.persist(persist_path=save_path, fs=fs)
def load( def load(self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None):
self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
) -> "InMemoryVectorStore":
"""Create a SimpleKVStore from a load directory. """Create a SimpleKVStore from a load directory.
@ -54,4 +52,4 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
load_path: Path of loading vector. load_path: Path of loading vector.
fs: An abstract super-class for pythonic file-systems fs: An abstract super-class for pythonic file-systems
""" """
return self._client.from_persist_path(persist_path=load_path, fs=fs) self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)

View File

@ -53,6 +53,7 @@ setuptools.setup(
"openai", "openai",
"chromadb", "chromadb",
"wikipedia", "wikipedia",
"duckduckgo-search",
"googlesearch-python", "googlesearch-python",
"python-dotenv", "python-dotenv",
"pytest-mock", "pytest-mock",

View File

@ -1,11 +1,19 @@
from unittest.mock import patch from unittest.mock import patch
import pytest
from kotaemon.llms.chats.openai import AzureChatOpenAI from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents.react import ReactAgent from kotaemon.pipelines.agents.react import ReactAgent
from kotaemon.pipelines.agents.rewoo import RewooAgent from kotaemon.pipelines.agents.rewoo import RewooAgent
from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool from kotaemon.pipelines.tools import (
BaseTool,
GoogleSearchTool,
LLMTool,
WikipediaTool,
)
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!" FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
_openai_chat_completion_responses_rewoo = [ _openai_chat_completion_responses_rewoo = [
{ {
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
@ -73,19 +81,61 @@ _openai_chat_completion_responses_react = [
] ]
] ]
_openai_chat_completion_responses_react_langchain_tool = [
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
"so I should gather information about it.\n"
"Action: Wikipedia\n"
"Action Input: Cinnamon AI company\n"
),
(
"The information retrieved from Wikipedia is not "
"about Cinnamon AI company, but about Blue Prism, "
"a British multinational software corporation. "
"I need to try another source to gather information "
"about Cinnamon AI company.\n"
"Action: duckduckgo_search\n"
"Action Input: Cinnamon AI company\n"
),
FINAL_RESPONSE_TEXT,
]
]
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create", @pytest.fixture
side_effect=_openai_chat_completion_responses_rewoo, def llm():
) return AzureChatOpenAI(
def test_rewoo_agent(openai_completion):
llm = AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/", openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy", openai_api_key="dummy",
openai_api_version="2023-03-15-preview", openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2", deployment_name="dummy-q2",
temperature=0, temperature=0,
) )
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_rewoo,
)
def test_rewoo_agent(openai_completion, llm):
plugins = [ plugins = [
GoogleSearchTool(), GoogleSearchTool(),
WikipediaTool(), WikipediaTool(),
@ -103,14 +153,7 @@ def test_rewoo_agent(openai_completion):
"openai.api_resources.chat_completion.ChatCompletion.create", "openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react, side_effect=_openai_chat_completion_responses_react,
) )
def test_react_agent(openai_completion): def test_react_agent(openai_completion, llm):
llm = AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
plugins = [ plugins = [
GoogleSearchTool(), GoogleSearchTool(),
WikipediaTool(), WikipediaTool(),
@ -121,3 +164,47 @@ def test_react_agent(openai_completion):
response = agent("Tell me about Cinnamon AI company") response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called() openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT assert response.output == FINAL_RESPONSE_TEXT
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent_langchain(openai_completion, llm):
from langchain.agents import AgentType, initialize_agent
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
langchain_plugins = [tool.to_langchain_format() for tool in plugins]
agent = initialize_agent(
langchain_plugins,
llm.agent,
agent=AgentType.OPENAI_FUNCTIONS,
verbose=True,
)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react_langchain_tool,
)
def test_react_agent_with_langchain_tools(openai_completion, llm):
from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
from langchain.utilities import WikipediaAPIWrapper
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
search = DuckDuckGoSearchRun()
langchain_plugins = [wikipedia, search]
plugins = [BaseTool.from_langchain_format(tool) for tool in langchain_plugins]
agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT

View File

@ -116,8 +116,8 @@ class TestInMemoryVectorStore:
"3" not in data["text_id_to_ref_doc_id"] "3" not in data["text_id_to_ref_doc_id"]
), "delete function does not delete data completely" ), "delete function does not delete data completely"
db2 = InMemoryVectorStore() db2 = InMemoryVectorStore()
output = db2.load(load_path=tmp_path / "test_save_load_delete.json") db2.load(load_path=tmp_path / "test_save_load_delete.json")
assert output.get("2") == [ assert db2.get("2") == [
0.4, 0.4,
0.5, 0.5,
0.6, 0.6,