Update Base interface of Index/Retrieval pipeline (#36)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

* add LLMTool

* update BaseAgent type

* add ReAct agent

* add ReAct agent

* minor update

* minor update

* minor update

* minor update

* update base reader with BaseComponent

* add splitter

* update agent and tool

* update vectorstores

* update load/save for indexing and retrieving pipeline

* update test_agent for more use-cases

* add missing dependency for test

* update test case for in memory vectorstore

* add TextSplitter to BaseComponent

* update type hint basetool

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2023-10-04 14:27:44 +07:00 committed by GitHub
parent 49ed3f6994
commit 56bc41b673
13 changed files with 302 additions and 36 deletions

View File

@ -1,6 +1,12 @@
from .base import AutoReader
from .base import AutoReader, DirectoryReader
from .excel_loader import PandasExcelReader
from .mathpix_loader import MathpixPDFReader
from .ocr_loader import OCRReader
__all__ = ["AutoReader", "PandasExcelReader", "MathpixPDFReader", "OCRReader"]
__all__ = [
"AutoReader",
"PandasExcelReader",
"MathpixPDFReader",
"OCRReader",
"DirectoryReader",
]

View File

@ -1,13 +1,14 @@
from pathlib import Path
from typing import Any, List, Type, Union
from llama_index import download_loader
from llama_index import SimpleDirectoryReader, download_loader
from llama_index.readers.base import BaseReader
from ..base import BaseComponent
from ..documents.base import Document
class AutoReader(BaseReader):
class AutoReader(BaseComponent, BaseReader):
"""General auto reader for a variety of files. (based on llama-hub)"""
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
@ -17,6 +18,7 @@ class AutoReader(BaseReader):
self._reader = download_loader(reader_type)()
else:
self._reader = reader_type()
super().__init__()
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
documents = self._reader.load_data(file=file, **kwargs)
@ -24,3 +26,42 @@ class AutoReader(BaseReader):
# convert Document to new base class from kotaemon
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
return converted_documents
def run(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
    """Load `file` by delegating to `load_data`.

    Args:
        file: path of the file to read.
        **kwargs: forwarded unchanged to `load_data`.

    Returns:
        The documents returned by `load_data`.
    """
    return self.load_data(file=file, **kwargs)
class LIBaseReader(BaseComponent, BaseReader):
    """Component wrapper around a LlamaIndex reader class.

    Subclasses set `_reader_class` to the reader class to wrap. An instance of
    that class is created in `__init__` and stored in `_reader`; attribute
    reads that are not found on the wrapper, and writes to non-underscore
    names, are forwarded to that wrapped instance.
    """

    _reader_class: Type[BaseReader]

    def __init__(self, *args, **kwargs):
        """Instantiate the wrapped reader with the given arguments.

        Args:
            *args: positional arguments passed to `_reader_class`.
            **kwargs: keyword arguments passed to `_reader_class`.

        Raises:
            AttributeError: if the subclass did not define `_reader_class`.
        """
        # `_reader_class` is only annotated on this base class, so it may be
        # missing entirely; look it up with a default so the guard below is
        # reachable instead of failing inside `__getattr__`.
        reader_class = getattr(self, "_reader_class", None)
        if reader_class is None:
            raise AttributeError(
                "Require `_reader_class` to set a BaseReader class from LlamarIndex"
            )
        self._reader = reader_class(*args, **kwargs)
        super().__init__()

    def __setattr__(self, name: str, value: Any) -> None:
        """Keep underscore-prefixed names on the wrapper, forward the rest."""
        if name.startswith("_"):
            return super().__setattr__(name, value)
        return setattr(self._reader, name, value)

    def __getattr__(self, name: str) -> Any:
        """Forward attribute reads that normal lookup could not resolve.

        Raises:
            AttributeError: for the wrapper's own bookkeeping names when they
                are not set.
        """
        # `__getattr__` only runs after normal lookup failed. If `_reader`
        # itself is missing, delegating through `self._reader` would re-enter
        # this method forever, so fail fast for the wrapper's own names.
        if name in ("_reader", "_reader_class"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._reader, name)

    def load_data(self, *args, **kwargs: Any) -> List[Document]:
        """Load documents with the wrapped reader and convert them.

        Args:
            *args: positional arguments passed to the wrapped `load_data`.
            **kwargs: keyword arguments passed to the wrapped `load_data`.

        Returns:
            The loaded documents rebuilt as kotaemon `Document` objects.
        """
        documents = self._reader.load_data(*args, **kwargs)
        # convert Document to new base class from kotaemon
        converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
        return converted_documents

    def run(self, *args, **kwargs: Any) -> List[Document]:
        """Component entry point; identical to calling `load_data`."""
        return self.load_data(*args, **kwargs)
class DirectoryReader(LIBaseReader):
    """`LIBaseReader` whose wrapped reader class is `SimpleDirectoryReader`."""

    # Reader class that `LIBaseReader.__init__` instantiates.
    _reader_class = SimpleDirectoryReader

View File

View File

@ -0,0 +1,65 @@
from typing import Any, List, Sequence, Type
from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
from llama_index.node_parser.interface import NodeParser
from llama_index.text_splitter import TokenTextSplitter
from kotaemon.base import BaseComponent
from ..documents.base import Document
__all__ = ["TokenTextSplitter"]
class LINodeParser(BaseComponent):
    """Component wrapper around a LlamaIndex node parser class.

    Subclasses set `_parser_class` to the parser class to wrap. An instance of
    that class is created in `__init__` and stored in `_parser`; attribute
    reads that are not found on the wrapper are forwarded to it, as are writes
    to names that are neither underscore-prefixed nor protected keywords.
    """

    _parser_class: Type[NodeParser]

    def __init__(self, *args, **kwargs):
        """Instantiate the wrapped parser with the given arguments.

        Args:
            *args: positional arguments passed to `_parser_class`.
            **kwargs: keyword arguments passed to `_parser_class`.

        Raises:
            AttributeError: if the subclass did not define `_parser_class`.
        """
        # `_parser_class` is only annotated on this base class, so it may be
        # missing entirely; look it up with a default so the guard below is
        # reachable instead of failing inside `__getattr__`.
        parser_class = getattr(self, "_parser_class", None)
        if parser_class is None:
            raise AttributeError(
                "Require `_parser_class` to set a NodeParser class from LlamarIndex"
            )
        self._parser = parser_class(*args, **kwargs)
        super().__init__()

    def __setattr__(self, name: str, value: Any) -> None:
        """Keep private and protected names on the wrapper, forward the rest."""
        # `_protected_keywords` is not defined in this class; presumably it is
        # inherited from `BaseComponent` (TODO confirm).
        if name.startswith("_") or name in self._protected_keywords():
            return super().__setattr__(name, value)
        return setattr(self._parser, name, value)

    def __getattr__(self, name: str) -> Any:
        """Forward attribute reads that normal lookup could not resolve.

        Raises:
            AttributeError: for the wrapper's own bookkeeping names when they
                are not set.
        """
        # `__getattr__` only runs after normal lookup failed. If `_parser`
        # itself is missing, delegating through `self._parser` would re-enter
        # this method forever, so fail fast for the wrapper's own names.
        if name in ("_parser", "_parser_class"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        return getattr(self._parser, name)

    def get_nodes_from_documents(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[Document]:
        """Split documents with the wrapped parser and convert the result.

        Args:
            documents: documents to split.
            show_progress: forwarded to the wrapped parser.

        Returns:
            The parser output rebuilt as kotaemon `Document` objects.
        """
        documents = self._parser.get_nodes_from_documents(
            documents=documents, show_progress=show_progress
        )
        # convert Document to new base class from kotaemon
        converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
        return converted_documents

    def run(
        self,
        documents: Sequence[Document],
        show_progress: bool = False,
    ) -> List[Document]:
        """Component entry point; identical to `get_nodes_from_documents`."""
        return self.get_nodes_from_documents(
            documents=documents, show_progress=show_progress
        )
class SimpleNodeParser(LINodeParser):
    """`LINodeParser` wrapping `LISimpleNodeParser` with a token text splitter.

    `chunk_size` (default 512) and `chunk_overlap` (default 0) are taken out
    of the keyword arguments and used to build the `TokenTextSplitter` that is
    passed on as `text_splitter`; any `text_splitter` supplied by the caller
    is replaced.
    """

    _parser_class = LISimpleNodeParser

    def __init__(self, *args, **kwargs):
        # Remove the splitter settings so they are not forwarded to the parser.
        size = kwargs.pop("chunk_size", 512)
        overlap = kwargs.pop("chunk_overlap", 0)
        splitter = TokenTextSplitter(chunk_size=size, chunk_overlap=overlap)
        kwargs["text_splitter"] = splitter
        super().__init__(*args, **kwargs)

View File

@ -1,3 +1,5 @@
from .base import BaseAgent
from .react.agent import ReactAgent
from .rewoo.agent import RewooAgent
__all__ = ["BaseAgent"]
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent"]

View File

@ -1,5 +1,6 @@
import uuid
from typing import List, Optional
from pathlib import Path
from typing import List, Union
from theflow import Node, Param
@ -9,6 +10,9 @@ from ..documents.base import Document
from ..embeddings import BaseEmbeddings
from ..vectorstores import BaseVectorStore
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
class IndexVectorStoreFromDocumentPipeline(BaseComponent):
"""Ingest the document, run through the embedding, and store the embedding in a
@ -20,7 +24,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
"""
vector_store: Param[BaseVectorStore] = Param()
doc_store: Optional[BaseDocumentStore] = None
doc_store: Param[BaseDocumentStore] = Param()
embedding: Node[BaseEmbeddings] = Node()
# TODO: refer to llama_index's storage as well
@ -30,7 +34,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
self.run_batch_document([document])
def run_batch_raw(self, text: List[str]) -> None:
documents = [Document(t, id_=str(uuid.uuid4())) for t in text]
documents = [Document(text=t, id_=str(uuid.uuid4())) for t in text]
self.run_batch_document(documents)
def run_document(self, text: Document) -> None:
@ -57,13 +61,31 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
return True
return False
def persist(self, path: str):
def save(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Write the pipeline's vector store and document store to disk.

    Args:
        path (str | Path): directory-like base path to save the state under
        vectorstore_fname (str): name, relative to `path`, passed to the
            vector store's `save`
        docstore_fname (str): name, relative to `path`, passed to the
            document store's `save`
    """
    # Plain strings are promoted to Path so the `/` joins below work.
    base_dir = Path(path) if isinstance(path, str) else path
    self.vector_store.save(base_dir / vectorstore_fname)
    self.doc_store.save(base_dir / docstore_fname)
def load(self, path: str):
def load(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Read the pipeline's vector store and document store back from disk.

    Args:
        path (str | Path): base path the state was saved under
        vectorstore_fname (str): name, relative to `path`, passed to the
            vector store's `load`
        docstore_fname (str): name, relative to `path`, passed to the
            document store's `load`
    """
    # Plain strings are promoted to Path so the `/` joins below work.
    base_dir = Path(path) if isinstance(path, str) else path
    self.vector_store.load(base_dir / vectorstore_fname)
    self.doc_store.load(base_dir / docstore_fname)

View File

@ -1,5 +1,6 @@
from abc import abstractmethod
from typing import List, Optional
from pathlib import Path
from typing import List, Union
from theflow import Node, Param
@ -9,6 +10,9 @@ from ..documents.base import Document, RetrievedDocument
from ..embeddings import BaseEmbeddings
from ..vectorstores import BaseVectorStore
VECTOR_STORE_FNAME = "vectorstore"
DOC_STORE_FNAME = "docstore"
class BaseRetrieval(BaseComponent):
"""Define the base interface of a retrieval pipeline"""
@ -38,7 +42,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
"""Retrieve list of documents from vector store"""
vector_store: Param[BaseVectorStore] = Param()
doc_store: Optional[BaseDocumentStore] = None
doc_store: Param[BaseDocumentStore] = Param()
embedding: Node[BaseEmbeddings] = Node()
# TODO: refer to llama_index's storage as well
@ -86,13 +90,31 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
return True
return False
def persist(self, path: str):
def save(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Write the retrieval pipeline's vector store and document store to disk.

    Args:
        path (str | Path): directory-like base path to save the state under
        vectorstore_fname (str): name, relative to `path`, passed to the
            vector store's `save`
        docstore_fname (str): name, relative to `path`, passed to the
            document store's `save`
    """
    # Plain strings are promoted to Path so the `/` joins below work.
    base_dir = Path(path) if isinstance(path, str) else path
    self.vector_store.save(base_dir / vectorstore_fname)
    self.doc_store.save(base_dir / docstore_fname)
def load(self, path: str):
def load(
    self,
    path: Union[str, Path],
    vectorstore_fname: str = VECTOR_STORE_FNAME,
    docstore_fname: str = DOC_STORE_FNAME,
):
    """Read the retrieval pipeline's vector store and document store from disk.

    Args:
        path (str | Path): base path the state was saved under
        vectorstore_fname (str): name, relative to `path`, passed to the
            vector store's `load`
        docstore_fname (str): name, relative to `path`, passed to the
            document store's `load`
    """
    # Plain strings are promoted to Path so the `/` joins below work.
    base_dir = Path(path) if isinstance(path, str) else path
    self.vector_store.load(base_dir / vectorstore_fname)
    self.doc_store.load(base_dir / docstore_fname)

View File

@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from langchain.agents import Tool as LCTool
from pydantic import BaseModel
from kotaemon.base import BaseComponent
@ -87,6 +88,10 @@ class BaseTool(BaseComponent):
)
return observation
def to_langchain_format(self) -> LCTool:
    """Build a Langchain `Tool` exposing this tool's name, description and `run`.

    Returns:
        An `LCTool` whose `func` is this tool's bound `run` method, so it can
        be handed to a Langchain agent.
    """
    langchain_tool = LCTool(
        name=self.name,
        description=self.description,
        func=self.run,
    )
    return langchain_tool
def run_raw(
self,
tool_input: Union[str, Dict],
@ -122,6 +127,15 @@ class BaseTool(BaseComponent):
"""Tool does not support processing batch"""
return False
@classmethod
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
    """Wrap a Langchain tool in a `BaseTool`.

    The result is always a plain `BaseTool` (not `cls`) that copies the
    Langchain tool's name and description and uses its `_run` as `_run_tool`.

    Args:
        langchain_tool: the Langchain tool to wrap.

    Returns:
        The wrapping `BaseTool` instance.
    """
    wrapped = BaseTool(
        name=langchain_tool.name,
        description=langchain_tool.description,
    )
    # Route execution to the Langchain tool's own implementation.
    wrapped._run_tool = langchain_tool._run  # type: ignore
    return wrapped
class ComponentTool(BaseTool):
"""
@ -130,6 +144,11 @@ class ComponentTool(BaseTool):
"""
component: BaseComponent
postprocessor: Optional[Callable] = None
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
return self.component(*args, **kwargs)
output = self.component(*args, **kwargs)
if self.postprocessor:
output = self.postprocessor(output)
return output

View File

@ -67,6 +67,9 @@ class ChromaVectorStore(LlamaIndexVectorStore):
collection_name = self._client.client.name
self._client.client._client.delete_collection(collection_name)
def count(self) -> int:
    """Return the value reported by the underlying collection's `count()`."""
    return self._collection.count()
def save(self, *args, **kwargs):
    """Accept any arguments and do nothing.

    NOTE(review): presumably a no-op because the Chroma client handles its
    own persistence — confirm.
    """
    pass

View File

@ -44,9 +44,7 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
"""
self._client.persist(persist_path=save_path, fs=fs)
def load(
self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
) -> "InMemoryVectorStore":
def load(self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None):
"""Create a SimpleKVStore from a load directory.
@ -54,4 +52,4 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
load_path: Path of loading vector.
fs: An abstract super-class for pythonic file-systems
"""
return self._client.from_persist_path(persist_path=load_path, fs=fs)
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)

View File

@ -53,6 +53,7 @@ setuptools.setup(
"openai",
"chromadb",
"wikipedia",
"duckduckgo-search",
"googlesearch-python",
"python-dotenv",
"pytest-mock",

View File

@ -1,11 +1,19 @@
from unittest.mock import patch
import pytest
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents.react import ReactAgent
from kotaemon.pipelines.agents.rewoo import RewooAgent
from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool
from kotaemon.pipelines.tools import (
BaseTool,
GoogleSearchTool,
LLMTool,
WikipediaTool,
)
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
_openai_chat_completion_responses_rewoo = [
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
@ -73,19 +81,61 @@ _openai_chat_completion_responses_react = [
]
]
# Canned chat-completion payloads, one per assistant message below, used as
# the mock `side_effect` for the test that drives ReactAgent with tools
# wrapped from Langchain. The last message is the expected final answer.
_openai_chat_completion_responses_react_langchain_tool = [
    {
        "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
        "object": "chat.completion",
        "created": 1692338378,
        "model": "gpt-35-turbo",
        "choices": [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": text,
                },
            }
        ],
        "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
    }
    for text in [
        (
            "I don't have prior knowledge about Cinnamon AI company, "
            "so I should gather information about it.\n"
            "Action: Wikipedia\n"
            "Action Input: Cinnamon AI company\n"
        ),
        (
            "The information retrieved from Wikipedia is not "
            "about Cinnamon AI company, but about Blue Prism, "
            "a British multinational software corporation. "
            "I need to try another source to gather information "
            "about Cinnamon AI company.\n"
            "Action: duckduckgo_search\n"
            "Action Input: Cinnamon AI company\n"
        ),
        FINAL_RESPONSE_TEXT,
    ]
]
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_rewoo,
)
def test_rewoo_agent(openai_completion):
llm = AzureChatOpenAI(
@pytest.fixture
def llm():
    """Provide an `AzureChatOpenAI` model configured with dummy credentials.

    The tests that use this fixture patch the OpenAI completion call, so the
    endpoint and key here are placeholders.
    """
    return AzureChatOpenAI(
        openai_api_base="https://dummy.openai.azure.com/",
        openai_api_key="dummy",
        openai_api_version="2023-03-15-preview",
        deployment_name="dummy-q2",
        temperature=0,
    )
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_rewoo,
)
def test_rewoo_agent(openai_completion, llm):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
@ -103,14 +153,7 @@ def test_rewoo_agent(openai_completion):
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent(openai_completion):
llm = AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
def test_react_agent(openai_completion, llm):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
@ -121,3 +164,47 @@ def test_react_agent(openai_completion):
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT
@patch(
    "openai.api_resources.chat_completion.ChatCompletion.create",
    side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent_langchain(openai_completion, llm):
    """Run a Langchain agent over kotaemon tools converted to Langchain format.

    NOTE(review): the mocked responses are the ReAct-style fixtures while the
    agent type is OPENAI_FUNCTIONS — confirm this pairing is intended.
    """
    from langchain.agents import AgentType, initialize_agent

    plugins = [
        GoogleSearchTool(),
        WikipediaTool(),
        LLMTool(llm=llm),
    ]
    # Convert each kotaemon tool so the Langchain agent can call it.
    langchain_plugins = [tool.to_langchain_format() for tool in plugins]
    agent = initialize_agent(
        langchain_plugins,
        llm.agent,
        agent=AgentType.OPENAI_FUNCTIONS,
        verbose=True,
    )
    response = agent("Tell me about Cinnamon AI company")
    openai_completion.assert_called()
    # Only checks that a truthy response came back, not its content.
    assert response
@patch(
    "openai.api_resources.chat_completion.ChatCompletion.create",
    side_effect=_openai_chat_completion_responses_react_langchain_tool,
)
def test_react_agent_with_langchain_tools(openai_completion, llm):
    """Run `ReactAgent` with Langchain tools wrapped via `from_langchain_format`."""
    from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
    from langchain.utilities import WikipediaAPIWrapper

    wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
    search = DuckDuckGoSearchRun()

    langchain_plugins = [wikipedia, search]
    # Wrap each Langchain tool as a kotaemon BaseTool for the agent.
    plugins = [BaseTool.from_langchain_format(tool) for tool in langchain_plugins]
    agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)

    response = agent("Tell me about Cinnamon AI company")
    openai_completion.assert_called()
    assert response.output == FINAL_RESPONSE_TEXT

View File

@ -116,8 +116,8 @@ class TestInMemoryVectorStore:
"3" not in data["text_id_to_ref_doc_id"]
), "delete function does not delete data completely"
db2 = InMemoryVectorStore()
output = db2.load(load_path=tmp_path / "test_save_load_delete.json")
assert output.get("2") == [
db2.load(load_path=tmp_path / "test_save_load_delete.json")
assert db2.get("2") == [
0.4,
0.5,
0.6,