Update Base interface of Index/Retrieval pipeline (#36)
* add base Tool * minor update test_tool * update test dependency * update test dependency * Fix namespace conflict * update test * add base Agent Interface, add ReWoo Agent * minor update * update test * fix typo * remove unneeded print * update rewoo agent * add LLMTool * update BaseAgent type * add ReAct agent * add ReAct agent * minor update * minor update * minor update * minor update * update base reader with BaseComponent * add splitter * update agent and tool * update vectorstores * update load/save for indexing and retrieving pipeline * update test_agent for more use-cases * add missing dependency for test * update test case for in memory vectorstore * add TextSplitter to BaseComponent * update type hint basetool --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
49ed3f6994
commit
56bc41b673
|
@ -1,6 +1,12 @@
|
||||||
from .base import AutoReader
|
from .base import AutoReader, DirectoryReader
|
||||||
from .excel_loader import PandasExcelReader
|
from .excel_loader import PandasExcelReader
|
||||||
from .mathpix_loader import MathpixPDFReader
|
from .mathpix_loader import MathpixPDFReader
|
||||||
from .ocr_loader import OCRReader
|
from .ocr_loader import OCRReader
|
||||||
|
|
||||||
__all__ = ["AutoReader", "PandasExcelReader", "MathpixPDFReader", "OCRReader"]
|
__all__ = [
|
||||||
|
"AutoReader",
|
||||||
|
"PandasExcelReader",
|
||||||
|
"MathpixPDFReader",
|
||||||
|
"OCRReader",
|
||||||
|
"DirectoryReader",
|
||||||
|
]
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, List, Type, Union
|
from typing import Any, List, Type, Union
|
||||||
|
|
||||||
from llama_index import download_loader
|
from llama_index import SimpleDirectoryReader, download_loader
|
||||||
from llama_index.readers.base import BaseReader
|
from llama_index.readers.base import BaseReader
|
||||||
|
|
||||||
|
from ..base import BaseComponent
|
||||||
from ..documents.base import Document
|
from ..documents.base import Document
|
||||||
|
|
||||||
|
|
||||||
class AutoReader(BaseReader):
|
class AutoReader(BaseComponent, BaseReader):
|
||||||
"""General auto reader for a variety of files. (based on llama-hub)"""
|
"""General auto reader for a variety of files. (based on llama-hub)"""
|
||||||
|
|
||||||
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
|
def __init__(self, reader_type: Union[str, Type[BaseReader]]) -> None:
|
||||||
|
@ -17,6 +18,7 @@ class AutoReader(BaseReader):
|
||||||
self._reader = download_loader(reader_type)()
|
self._reader = download_loader(reader_type)()
|
||||||
else:
|
else:
|
||||||
self._reader = reader_type()
|
self._reader = reader_type()
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
|
def load_data(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
|
||||||
documents = self._reader.load_data(file=file, **kwargs)
|
documents = self._reader.load_data(file=file, **kwargs)
|
||||||
|
@ -24,3 +26,42 @@ class AutoReader(BaseReader):
|
||||||
# convert Document to new base class from kotaemon
|
# convert Document to new base class from kotaemon
|
||||||
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
||||||
return converted_documents
|
return converted_documents
|
||||||
|
|
||||||
|
def run(self, file: Union[Path, str], **kwargs: Any) -> List[Document]:
|
||||||
|
return self.load_data(file=file, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class LIBaseReader(BaseComponent, BaseReader):
|
||||||
|
_reader_class: Type[BaseReader]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if self._reader_class is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Require `_reader_class` to set a BaseReader class from LlamarIndex"
|
||||||
|
)
|
||||||
|
|
||||||
|
self._reader = self._reader_class(*args, **kwargs)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
if name.startswith("_"):
|
||||||
|
return super().__setattr__(name, value)
|
||||||
|
|
||||||
|
return setattr(self._reader, name, value)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
return getattr(self._reader, name)
|
||||||
|
|
||||||
|
def load_data(self, *args, **kwargs: Any) -> List[Document]:
|
||||||
|
documents = self._reader.load_data(*args, **kwargs)
|
||||||
|
|
||||||
|
# convert Document to new base class from kotaemon
|
||||||
|
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
||||||
|
return converted_documents
|
||||||
|
|
||||||
|
def run(self, *args, **kwargs: Any) -> List[Document]:
|
||||||
|
return self.load_data(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class DirectoryReader(LIBaseReader):
|
||||||
|
_reader_class = SimpleDirectoryReader
|
||||||
|
|
0
knowledgehub/parsers/__init__.py
Normal file
0
knowledgehub/parsers/__init__.py
Normal file
65
knowledgehub/parsers/splitter.py
Normal file
65
knowledgehub/parsers/splitter.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
from typing import Any, List, Sequence, Type
|
||||||
|
|
||||||
|
from llama_index.node_parser import SimpleNodeParser as LISimpleNodeParser
|
||||||
|
from llama_index.node_parser.interface import NodeParser
|
||||||
|
from llama_index.text_splitter import TokenTextSplitter
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
|
||||||
|
from ..documents.base import Document
|
||||||
|
|
||||||
|
__all__ = ["TokenTextSplitter"]
|
||||||
|
|
||||||
|
|
||||||
|
class LINodeParser(BaseComponent):
|
||||||
|
_parser_class: Type[NodeParser]
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
if self._parser_class is None:
|
||||||
|
raise AttributeError(
|
||||||
|
"Require `_parser_class` to set a NodeParser class from LlamarIndex"
|
||||||
|
)
|
||||||
|
self._parser = self._parser_class(*args, **kwargs)
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
if name.startswith("_") or name in self._protected_keywords():
|
||||||
|
return super().__setattr__(name, value)
|
||||||
|
|
||||||
|
return setattr(self._parser, name, value)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
return getattr(self._parser, name)
|
||||||
|
|
||||||
|
def get_nodes_from_documents(
|
||||||
|
self,
|
||||||
|
documents: Sequence[Document],
|
||||||
|
show_progress: bool = False,
|
||||||
|
) -> List[Document]:
|
||||||
|
documents = self._parser.get_nodes_from_documents(
|
||||||
|
documents=documents, show_progress=show_progress
|
||||||
|
)
|
||||||
|
# convert Document to new base class from kotaemon
|
||||||
|
converted_documents = [Document.from_dict(doc.to_dict()) for doc in documents]
|
||||||
|
return converted_documents
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
documents: Sequence[Document],
|
||||||
|
show_progress: bool = False,
|
||||||
|
) -> List[Document]:
|
||||||
|
return self.get_nodes_from_documents(
|
||||||
|
documents=documents, show_progress=show_progress
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleNodeParser(LINodeParser):
|
||||||
|
_parser_class = LISimpleNodeParser
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
chunk_size = kwargs.pop("chunk_size", 512)
|
||||||
|
chunk_overlap = kwargs.pop("chunk_overlap", 0)
|
||||||
|
kwargs["text_splitter"] = TokenTextSplitter(
|
||||||
|
chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||||
|
)
|
||||||
|
super().__init__(*args, **kwargs)
|
|
@ -1,3 +1,5 @@
|
||||||
from .base import BaseAgent
|
from .base import BaseAgent
|
||||||
|
from .react.agent import ReactAgent
|
||||||
|
from .rewoo.agent import RewooAgent
|
||||||
|
|
||||||
__all__ = ["BaseAgent"]
|
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent"]
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
from typing import List, Optional
|
from pathlib import Path
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
from theflow import Node, Param
|
from theflow import Node, Param
|
||||||
|
|
||||||
|
@ -9,6 +10,9 @@ from ..documents.base import Document
|
||||||
from ..embeddings import BaseEmbeddings
|
from ..embeddings import BaseEmbeddings
|
||||||
from ..vectorstores import BaseVectorStore
|
from ..vectorstores import BaseVectorStore
|
||||||
|
|
||||||
|
VECTOR_STORE_FNAME = "vectorstore"
|
||||||
|
DOC_STORE_FNAME = "docstore"
|
||||||
|
|
||||||
|
|
||||||
class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
||||||
"""Ingest the document, run through the embedding, and store the embedding in a
|
"""Ingest the document, run through the embedding, and store the embedding in a
|
||||||
|
@ -20,7 +24,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
vector_store: Param[BaseVectorStore] = Param()
|
vector_store: Param[BaseVectorStore] = Param()
|
||||||
doc_store: Optional[BaseDocumentStore] = None
|
doc_store: Param[BaseDocumentStore] = Param()
|
||||||
embedding: Node[BaseEmbeddings] = Node()
|
embedding: Node[BaseEmbeddings] = Node()
|
||||||
|
|
||||||
# TODO: refer to llama_index's storage as well
|
# TODO: refer to llama_index's storage as well
|
||||||
|
@ -30,7 +34,7 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
||||||
self.run_batch_document([document])
|
self.run_batch_document([document])
|
||||||
|
|
||||||
def run_batch_raw(self, text: List[str]) -> None:
|
def run_batch_raw(self, text: List[str]) -> None:
|
||||||
documents = [Document(t, id_=str(uuid.uuid4())) for t in text]
|
documents = [Document(text=t, id_=str(uuid.uuid4())) for t in text]
|
||||||
self.run_batch_document(documents)
|
self.run_batch_document(documents)
|
||||||
|
|
||||||
def run_document(self, text: Document) -> None:
|
def run_document(self, text: Document) -> None:
|
||||||
|
@ -57,13 +61,31 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def persist(self, path: str):
|
def save(
|
||||||
|
self,
|
||||||
|
path: Union[str, Path],
|
||||||
|
vectorstore_fname: str = VECTOR_STORE_FNAME,
|
||||||
|
docstore_fname: str = DOC_STORE_FNAME,
|
||||||
|
):
|
||||||
"""Save the whole state of the indexing pipeline vector store and all
|
"""Save the whole state of the indexing pipeline vector store and all
|
||||||
necessary information to disk
|
necessary information to disk
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): path to save the state
|
path (str): path to save the state
|
||||||
"""
|
"""
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
self.vector_store.save(path / vectorstore_fname)
|
||||||
|
self.doc_store.save(path / docstore_fname)
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(
|
||||||
|
self,
|
||||||
|
path: Union[str, Path],
|
||||||
|
vectorstore_fname: str = VECTOR_STORE_FNAME,
|
||||||
|
docstore_fname: str = DOC_STORE_FNAME,
|
||||||
|
):
|
||||||
"""Load all information from disk to an object"""
|
"""Load all information from disk to an object"""
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
self.vector_store.load(path / vectorstore_fname)
|
||||||
|
self.doc_store.load(path / docstore_fname)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import List, Optional
|
from pathlib import Path
|
||||||
|
from typing import List, Union
|
||||||
|
|
||||||
from theflow import Node, Param
|
from theflow import Node, Param
|
||||||
|
|
||||||
|
@ -9,6 +10,9 @@ from ..documents.base import Document, RetrievedDocument
|
||||||
from ..embeddings import BaseEmbeddings
|
from ..embeddings import BaseEmbeddings
|
||||||
from ..vectorstores import BaseVectorStore
|
from ..vectorstores import BaseVectorStore
|
||||||
|
|
||||||
|
VECTOR_STORE_FNAME = "vectorstore"
|
||||||
|
DOC_STORE_FNAME = "docstore"
|
||||||
|
|
||||||
|
|
||||||
class BaseRetrieval(BaseComponent):
|
class BaseRetrieval(BaseComponent):
|
||||||
"""Define the base interface of a retrieval pipeline"""
|
"""Define the base interface of a retrieval pipeline"""
|
||||||
|
@ -38,7 +42,7 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
|
||||||
"""Retrieve list of documents from vector store"""
|
"""Retrieve list of documents from vector store"""
|
||||||
|
|
||||||
vector_store: Param[BaseVectorStore] = Param()
|
vector_store: Param[BaseVectorStore] = Param()
|
||||||
doc_store: Optional[BaseDocumentStore] = None
|
doc_store: Param[BaseDocumentStore] = Param()
|
||||||
embedding: Node[BaseEmbeddings] = Node()
|
embedding: Node[BaseEmbeddings] = Node()
|
||||||
# TODO: refer to llama_index's storage as well
|
# TODO: refer to llama_index's storage as well
|
||||||
|
|
||||||
|
@ -86,13 +90,31 @@ class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def persist(self, path: str):
|
def save(
|
||||||
|
self,
|
||||||
|
path: Union[str, Path],
|
||||||
|
vectorstore_fname: str = VECTOR_STORE_FNAME,
|
||||||
|
docstore_fname: str = DOC_STORE_FNAME,
|
||||||
|
):
|
||||||
"""Save the whole state of the indexing pipeline vector store and all
|
"""Save the whole state of the indexing pipeline vector store and all
|
||||||
necessary information to disk
|
necessary information to disk
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
path (str): path to save the state
|
path (str): path to save the state
|
||||||
"""
|
"""
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
self.vector_store.save(path / vectorstore_fname)
|
||||||
|
self.doc_store.save(path / docstore_fname)
|
||||||
|
|
||||||
def load(self, path: str):
|
def load(
|
||||||
|
self,
|
||||||
|
path: Union[str, Path],
|
||||||
|
vectorstore_fname: str = VECTOR_STORE_FNAME,
|
||||||
|
docstore_fname: str = DOC_STORE_FNAME,
|
||||||
|
):
|
||||||
"""Load all information from disk to an object"""
|
"""Load all information from disk to an object"""
|
||||||
|
if isinstance(path, str):
|
||||||
|
path = Path(path)
|
||||||
|
self.vector_store.load(path / vectorstore_fname)
|
||||||
|
self.doc_store.load(path / docstore_fname)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from langchain.agents import Tool as LCTool
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
|
@ -87,6 +88,10 @@ class BaseTool(BaseComponent):
|
||||||
)
|
)
|
||||||
return observation
|
return observation
|
||||||
|
|
||||||
|
def to_langchain_format(self) -> LCTool:
|
||||||
|
"""Convert this tool to Langchain format to use with its agent"""
|
||||||
|
return LCTool(name=self.name, description=self.description, func=self.run)
|
||||||
|
|
||||||
def run_raw(
|
def run_raw(
|
||||||
self,
|
self,
|
||||||
tool_input: Union[str, Dict],
|
tool_input: Union[str, Dict],
|
||||||
|
@ -122,6 +127,15 @@ class BaseTool(BaseComponent):
|
||||||
"""Tool does not support processing batch"""
|
"""Tool does not support processing batch"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
|
||||||
|
"""Wrapper for Langchain Tool"""
|
||||||
|
new_tool = BaseTool(
|
||||||
|
name=langchain_tool.name, description=langchain_tool.description
|
||||||
|
)
|
||||||
|
new_tool._run_tool = langchain_tool._run # type: ignore
|
||||||
|
return new_tool
|
||||||
|
|
||||||
|
|
||||||
class ComponentTool(BaseTool):
|
class ComponentTool(BaseTool):
|
||||||
"""
|
"""
|
||||||
|
@ -130,6 +144,11 @@ class ComponentTool(BaseTool):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
component: BaseComponent
|
component: BaseComponent
|
||||||
|
postprocessor: Optional[Callable] = None
|
||||||
|
|
||||||
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
|
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
|
||||||
return self.component(*args, **kwargs)
|
output = self.component(*args, **kwargs)
|
||||||
|
if self.postprocessor:
|
||||||
|
output = self.postprocessor(output)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
|
@ -67,6 +67,9 @@ class ChromaVectorStore(LlamaIndexVectorStore):
|
||||||
collection_name = self._client.client.name
|
collection_name = self._client.client.name
|
||||||
self._client.client._client.delete_collection(collection_name)
|
self._client.client._client.delete_collection(collection_name)
|
||||||
|
|
||||||
|
def count(self) -> int:
|
||||||
|
return self._collection.count()
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
|
@ -44,9 +44,7 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
|
||||||
"""
|
"""
|
||||||
self._client.persist(persist_path=save_path, fs=fs)
|
self._client.persist(persist_path=save_path, fs=fs)
|
||||||
|
|
||||||
def load(
|
def load(self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None):
|
||||||
self, load_path: str, fs: Optional[fsspec.AbstractFileSystem] = None
|
|
||||||
) -> "InMemoryVectorStore":
|
|
||||||
|
|
||||||
"""Create a SimpleKVStore from a load directory.
|
"""Create a SimpleKVStore from a load directory.
|
||||||
|
|
||||||
|
@ -54,4 +52,4 @@ class InMemoryVectorStore(LlamaIndexVectorStore):
|
||||||
load_path: Path of loading vector.
|
load_path: Path of loading vector.
|
||||||
fs: An abstract super-class for pythonic file-systems
|
fs: An abstract super-class for pythonic file-systems
|
||||||
"""
|
"""
|
||||||
return self._client.from_persist_path(persist_path=load_path, fs=fs)
|
self._client = self._client.from_persist_path(persist_path=load_path, fs=fs)
|
||||||
|
|
1
setup.py
1
setup.py
|
@ -53,6 +53,7 @@ setuptools.setup(
|
||||||
"openai",
|
"openai",
|
||||||
"chromadb",
|
"chromadb",
|
||||||
"wikipedia",
|
"wikipedia",
|
||||||
|
"duckduckgo-search",
|
||||||
"googlesearch-python",
|
"googlesearch-python",
|
||||||
"python-dotenv",
|
"python-dotenv",
|
||||||
"pytest-mock",
|
"pytest-mock",
|
||||||
|
|
|
@ -1,11 +1,19 @@
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
from kotaemon.pipelines.agents.react import ReactAgent
|
from kotaemon.pipelines.agents.react import ReactAgent
|
||||||
from kotaemon.pipelines.agents.rewoo import RewooAgent
|
from kotaemon.pipelines.agents.rewoo import RewooAgent
|
||||||
from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool
|
from kotaemon.pipelines.tools import (
|
||||||
|
BaseTool,
|
||||||
|
GoogleSearchTool,
|
||||||
|
LLMTool,
|
||||||
|
WikipediaTool,
|
||||||
|
)
|
||||||
|
|
||||||
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
|
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
|
||||||
|
|
||||||
_openai_chat_completion_responses_rewoo = [
|
_openai_chat_completion_responses_rewoo = [
|
||||||
{
|
{
|
||||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||||
|
@ -73,19 +81,61 @@ _openai_chat_completion_responses_react = [
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_openai_chat_completion_responses_react_langchain_tool = [
|
||||||
|
{
|
||||||
|
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||||
|
"object": "chat.completion",
|
||||||
|
"created": 1692338378,
|
||||||
|
"model": "gpt-35-turbo",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": "stop",
|
||||||
|
"message": {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": text,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||||
|
}
|
||||||
|
for text in [
|
||||||
|
(
|
||||||
|
"I don't have prior knowledge about Cinnamon AI company, "
|
||||||
|
"so I should gather information about it.\n"
|
||||||
|
"Action: Wikipedia\n"
|
||||||
|
"Action Input: Cinnamon AI company\n"
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"The information retrieved from Wikipedia is not "
|
||||||
|
"about Cinnamon AI company, but about Blue Prism, "
|
||||||
|
"a British multinational software corporation. "
|
||||||
|
"I need to try another source to gather information "
|
||||||
|
"about Cinnamon AI company.\n"
|
||||||
|
"Action: duckduckgo_search\n"
|
||||||
|
"Action Input: Cinnamon AI company\n"
|
||||||
|
),
|
||||||
|
FINAL_RESPONSE_TEXT,
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
@patch(
|
|
||||||
"openai.api_resources.chat_completion.ChatCompletion.create",
|
@pytest.fixture
|
||||||
side_effect=_openai_chat_completion_responses_rewoo,
|
def llm():
|
||||||
)
|
return AzureChatOpenAI(
|
||||||
def test_rewoo_agent(openai_completion):
|
|
||||||
llm = AzureChatOpenAI(
|
|
||||||
openai_api_base="https://dummy.openai.azure.com/",
|
openai_api_base="https://dummy.openai.azure.com/",
|
||||||
openai_api_key="dummy",
|
openai_api_key="dummy",
|
||||||
openai_api_version="2023-03-15-preview",
|
openai_api_version="2023-03-15-preview",
|
||||||
deployment_name="dummy-q2",
|
deployment_name="dummy-q2",
|
||||||
temperature=0,
|
temperature=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
|
side_effect=_openai_chat_completion_responses_rewoo,
|
||||||
|
)
|
||||||
|
def test_rewoo_agent(openai_completion, llm):
|
||||||
plugins = [
|
plugins = [
|
||||||
GoogleSearchTool(),
|
GoogleSearchTool(),
|
||||||
WikipediaTool(),
|
WikipediaTool(),
|
||||||
|
@ -103,14 +153,7 @@ def test_rewoo_agent(openai_completion):
|
||||||
"openai.api_resources.chat_completion.ChatCompletion.create",
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
side_effect=_openai_chat_completion_responses_react,
|
side_effect=_openai_chat_completion_responses_react,
|
||||||
)
|
)
|
||||||
def test_react_agent(openai_completion):
|
def test_react_agent(openai_completion, llm):
|
||||||
llm = AzureChatOpenAI(
|
|
||||||
openai_api_base="https://dummy.openai.azure.com/",
|
|
||||||
openai_api_key="dummy",
|
|
||||||
openai_api_version="2023-03-15-preview",
|
|
||||||
deployment_name="dummy-q2",
|
|
||||||
temperature=0,
|
|
||||||
)
|
|
||||||
plugins = [
|
plugins = [
|
||||||
GoogleSearchTool(),
|
GoogleSearchTool(),
|
||||||
WikipediaTool(),
|
WikipediaTool(),
|
||||||
|
@ -121,3 +164,47 @@ def test_react_agent(openai_completion):
|
||||||
response = agent("Tell me about Cinnamon AI company")
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
openai_completion.assert_called()
|
openai_completion.assert_called()
|
||||||
assert response.output == FINAL_RESPONSE_TEXT
|
assert response.output == FINAL_RESPONSE_TEXT
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
|
side_effect=_openai_chat_completion_responses_react,
|
||||||
|
)
|
||||||
|
def test_react_agent_langchain(openai_completion, llm):
|
||||||
|
from langchain.agents import AgentType, initialize_agent
|
||||||
|
|
||||||
|
plugins = [
|
||||||
|
GoogleSearchTool(),
|
||||||
|
WikipediaTool(),
|
||||||
|
LLMTool(llm=llm),
|
||||||
|
]
|
||||||
|
langchain_plugins = [tool.to_langchain_format() for tool in plugins]
|
||||||
|
agent = initialize_agent(
|
||||||
|
langchain_plugins,
|
||||||
|
llm.agent,
|
||||||
|
agent=AgentType.OPENAI_FUNCTIONS,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
|
openai_completion.assert_called()
|
||||||
|
assert response
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.api_resources.chat_completion.ChatCompletion.create",
|
||||||
|
side_effect=_openai_chat_completion_responses_react_langchain_tool,
|
||||||
|
)
|
||||||
|
def test_react_agent_with_langchain_tools(openai_completion, llm):
|
||||||
|
from langchain.tools import DuckDuckGoSearchRun, WikipediaQueryRun
|
||||||
|
from langchain.utilities import WikipediaAPIWrapper
|
||||||
|
|
||||||
|
wikipedia = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
|
||||||
|
search = DuckDuckGoSearchRun()
|
||||||
|
|
||||||
|
langchain_plugins = [wikipedia, search]
|
||||||
|
plugins = [BaseTool.from_langchain_format(tool) for tool in langchain_plugins]
|
||||||
|
agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
|
||||||
|
|
||||||
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
|
openai_completion.assert_called()
|
||||||
|
assert response.output == FINAL_RESPONSE_TEXT
|
||||||
|
|
|
@ -116,8 +116,8 @@ class TestInMemoryVectorStore:
|
||||||
"3" not in data["text_id_to_ref_doc_id"]
|
"3" not in data["text_id_to_ref_doc_id"]
|
||||||
), "delete function does not delete data completely"
|
), "delete function does not delete data completely"
|
||||||
db2 = InMemoryVectorStore()
|
db2 = InMemoryVectorStore()
|
||||||
output = db2.load(load_path=tmp_path / "test_save_load_delete.json")
|
db2.load(load_path=tmp_path / "test_save_load_delete.json")
|
||||||
assert output.get("2") == [
|
assert db2.get("2") == [
|
||||||
0.4,
|
0.4,
|
||||||
0.5,
|
0.5,
|
||||||
0.6,
|
0.6,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user