Refactor agents and tools (#91)

* Move tools to agents

* Move agents to dedicate place

* Remove subclassing BaseAgent from BaseTool
This commit is contained in:
Nguyen Trung Duc (john) 2023-11-30 09:52:08 +07:00 committed by GitHub
parent 4256030b4f
commit 8e3a1d193f
24 changed files with 126 additions and 124 deletions

View File

@ -0,0 +1,68 @@
from enum import Enum
from typing import Optional, Union
from theflow import Node, Param
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from .tools import BaseTool
BaseLLM = Union[ChatLLM, LLM]
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
@staticmethod
def get_agent_class(_type: "AgentType"):
"""
Get agent class from agent type.
:param _type: agent type
:return: agent class
"""
if _type == AgentType.rewoo:
from .rewoo.agent import RewooAgent
return RewooAgent
else:
raise ValueError(f"Unknown agent type: {_type}")
class BaseAgent(BaseComponent):
"""Define base agent interface"""
name: str = Param(help="Name of the agent.")
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
description: str = Param(
help="Description used to tell the model how/when/why to use the agent. "
"You can provide few-shot examples as a part of the description. This will be "
"input to the prompt of LLM."
)
llm: Union[BaseLLM, dict[str, BaseLLM]] = Node(
help="Specify LLM to be used in the model, cam be a dict to supply different "
"LLMs to multiple purposes in the agent"
)
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
help="A prompt template or a dict to supply different prompt to the agent"
)
plugins: list[BaseTool] = Param(
default_callback=lambda _: [],
help="List of plugins / tools to be used in the agent",
)
def add_tools(self, tools: list[BaseTool]) -> None:
"""Helper method to add tools and update agent state if needed"""
self.plugins.extend(tools)

View File

@ -1,14 +1,13 @@
from typing import List, Optional, Type
from typing import List, Optional
from langchain.agents import AgentType as LCAgentType
from langchain.agents import initialize_agent
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
from pydantic import BaseModel, create_model
from kotaemon.agents.tools import BaseTool
from kotaemon.base.schema import Document
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.pipelines.tools import BaseTool
from .base import AgentType, BaseAgent
@ -19,9 +18,6 @@ class LangchainAgent(BaseAgent):
name: str = "LangchainAgent"
agent_type: AgentType
description: str = "LangchainAgent for answering multi-step reasoning questions"
args_schema: Optional[Type[BaseModel]] = create_model(
"LangchainArgsSchema", instruction=(str, ...)
)
AGENT_TYPE_MAP = {
AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
@ -69,7 +65,7 @@ class LangchainAgent(BaseAgent):
self.update_agent_tools()
return
def _run_tool(self, instruction: str) -> Document:
def run(self, instruction: str) -> Document:
assert (
self.agent is not None
), "Lanchain AgentExecutor is not correclty initialized"

View File

@ -1,8 +1,8 @@
import logging
import re
from typing import Dict, List, Optional, Tuple, Type, Union
from typing import Optional
from pydantic import BaseModel, create_model
from theflow import Param
from kotaemon.base.schema import Document
from kotaemon.llms import PromptTemplate
@ -22,15 +22,18 @@ class ReactAgent(BaseAgent):
name: str = "ReactAgent"
agent_type: AgentType = AgentType.react
description: str = "ReactAgent for answering multi-step reasoning questions"
llm: Union[BaseLLM, Dict[str, BaseLLM]]
llm: BaseLLM | dict[str, BaseLLM]
prompt_template: Optional[PromptTemplate] = None
plugins: List[BaseTool] = list()
examples: Dict[str, Union[str, List[str]]] = dict()
args_schema: Optional[Type[BaseModel]] = create_model(
"ReactArgsSchema", instruction=(str, ...)
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="List of tools to be used in the agent. "
)
examples: dict[str, str | list[str]] = Param(
default_callback=lambda _: {}, help="Examples to be used in the agent. "
)
intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = Param(
default_callback=lambda _: [],
help="List of AgentAction and observation (tool) output",
)
intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
"""List of AgentAction and observation (tool) output"""
max_iterations = 10
strict_decode: bool = False
@ -51,7 +54,7 @@ class ReactAgent(BaseAgent):
return prompt
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
self, intermediate_steps: list[tuple[AgentAction | AgentFinish, str]] = []
) -> str:
"""Construct the scratchpad that lets the agent continue its thought process."""
thoughts = ""
@ -60,7 +63,7 @@ class ReactAgent(BaseAgent):
thoughts += f"\nObservation: {observation}\nThought:"
return thoughts
def _parse_output(self, text: str) -> Optional[Union[AgentAction, AgentFinish]]:
def _parse_output(self, text: str) -> Optional[AgentAction | AgentFinish]:
"""
Parse text output from LLM for the next Action or Final Answer
Using Regex to parse "Action:\n Action Input:\n" for the next Action
@ -74,7 +77,7 @@ class ReactAgent(BaseAgent):
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)
action_match = re.search(regex, text, re.DOTALL)
action_output: Optional[Union[AgentAction, AgentFinish]] = None
action_output: Optional[AgentAction | AgentFinish] = None
if action_match:
if includes_answer:
raise Exception(
@ -120,7 +123,7 @@ class ReactAgent(BaseAgent):
tool_names=tool_names,
)
def _format_function_map(self) -> Dict[str, BaseTool]:
def _format_function_map(self) -> dict[str, BaseTool]:
"""Format the function map for the open AI function API.
Return:

View File

@ -1,9 +1,9 @@
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any
from pydantic import BaseModel, create_model
from theflow import Param
from kotaemon.base.schema import Document
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
@ -23,16 +23,16 @@ class RewooAgent(BaseAgent):
name: str = "RewooAgent"
agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions"
llm: Union[BaseLLM, Dict[str, BaseLLM]] # {"Planner": xxx, "Solver": xxx}
prompt_template: Dict[
str, PromptTemplate
] = dict() # {"Planner": xxx, "Solver": xxx}
plugins: List[BaseTool] = list()
examples: Dict[
str, Union[str, List[str]]
] = dict() # {"Planner": xxx, "Solver": xxx}
args_schema: Optional[Type[BaseModel]] = create_model(
"RewooArgsSchema", instruction=(str, ...)
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
prompt_template: dict[str, PromptTemplate] = Param(
default_callback=lambda _: {},
help="A dict to supply different prompt to the agent.",
)
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="A list of plugins to be used in the model."
)
examples: dict[str, str | list[str]] = Param(
default_callback=lambda _: {}, help="Examples to be used in the agent."
)
def _get_llms(self):
@ -49,7 +49,7 @@ class RewooAgent(BaseAgent):
def _parse_plan_map(
self, planner_response: str
) -> Tuple[Dict[str, List[str]], Dict[str, str]]:
) -> tuple[dict[str, list[str]], dict[str, str]]:
"""
Parse planner output. It should be an n-to-n mapping from Plans to #Es.
This is because sometimes LLM cannot follow the strict output format.
@ -66,7 +66,7 @@ class RewooAgent(BaseAgent):
This function should also return a plan map.
Returns:
Tuple[Dict[str, List[str]], Dict[str, str]]: A list of plan map
tuple[Dict[str, List[str]], Dict[str, str]]: A list of plan map
"""
valid_chunk = [
line
@ -74,8 +74,8 @@ class RewooAgent(BaseAgent):
if line.startswith("#Plan") or line.startswith("#E")
]
plan_to_es: Dict[str, List[str]] = dict()
plans: Dict[str, str] = dict()
plan_to_es: dict[str, list[str]] = dict()
plans: dict[str, str] = dict()
for line in valid_chunk:
if line.startswith("#Plan"):
plan = line.split(":", 1)[0].strip()
@ -88,7 +88,7 @@ class RewooAgent(BaseAgent):
def _parse_planner_evidences(
self, planner_response: str
) -> Tuple[Dict[str, str], List[List[str]]]:
) -> tuple[dict[str, str], list[list[str]]]:
"""
Parse planner output. This should return a mapping from #E to tool call.
It should also identify the level of each #E in dependency map.
@ -99,11 +99,11 @@ class RewooAgent(BaseAgent):
}, [[#E1, #E2], [#E3, #E4]]
Returns:
Tuple[dict[str, str], List[List[str]]]:
tuple[dict[str, str], List[List[str]]]:
A mapping from #E to tool call and a list of levels.
"""
evidences: Dict[str, str] = dict()
dependence: Dict[str, List[str]] = dict()
evidences: dict[str, str] = dict()
dependence: dict[str, list[str]] = dict()
for line in planner_response.splitlines():
if line.startswith("#E") and line[2].isdigit():
e, tool_call = line.split(":", 1)
@ -134,8 +134,8 @@ class RewooAgent(BaseAgent):
def _run_plugin(
self,
e: str,
planner_evidences: Dict[str, str],
worker_evidences: Dict[str, str],
planner_evidences: dict[str, str],
worker_evidences: dict[str, str],
output=BaseScratchPad(),
):
"""
@ -169,8 +169,8 @@ class RewooAgent(BaseAgent):
def _get_worker_evidence(
self,
planner_evidences: Dict[str, str],
evidences_level: List[List[str]],
planner_evidences: dict[str, str],
evidences_level: list[list[str]],
output=BaseScratchPad(),
) -> Any:
"""
@ -185,7 +185,7 @@ class RewooAgent(BaseAgent):
Returns:
A mapping from #E to tool call.
"""
worker_evidences: Dict[str, str] = dict()
worker_evidences: dict[str, str] = dict()
plugin_cost, plugin_token = 0.0, 0.0
with ThreadPoolExecutor() as pool:
for level in evidences_level:
@ -218,7 +218,7 @@ class RewooAgent(BaseAgent):
if p.name == name:
return p
def _run_tool(self, instruction: str, use_citation: bool = False) -> Document:
def run(self, instruction: str, use_citation: bool = False) -> Document:
"""
Run the agent with a given instruction.
"""

View File

@ -1,7 +1,8 @@
from typing import Any, List, Optional, Union
from ....base import BaseComponent
from ....llms import PromptTemplate
from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate
from ..base import BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt

View File

@ -1,4 +1,4 @@
from ...base import Document
from kotaemon.base import Document
def get_plugin_response_content(output) -> str:

View File

@ -1,61 +0,0 @@
from enum import Enum
from typing import Dict, List, Optional, Union
from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.pipelines.tools import BaseTool
BaseLLM = Union[ChatLLM, LLM]
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
@staticmethod
def get_agent_class(_type: "AgentType"):
"""
Get agent class from agent type.
:param _type: agent type
:return: agent class
"""
if _type == AgentType.rewoo:
from .rewoo.agent import RewooAgent
return RewooAgent
else:
raise ValueError(f"Unknown agent type: {_type}")
class BaseAgent(BaseTool):
name: str
"""Name of the agent."""
agent_type: AgentType
"""Agent type, must be one of AgentType"""
description: str
"""Description used to tell the model how/when/why to use the agent.
You can provide few-shot examples as a part of the description. This will be
input to the prompt of LLM."""
llm: Union[BaseLLM, Dict[str, BaseLLM]]
"""Specify LLM to be used in the model, cam be a dict to supply different
LLMs to multiple purposes in the agent"""
prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]]
"""A prompt template or a dict to supply different prompt to the agent
"""
plugins: List[BaseTool] = []
"""List of plugins / tools to be used in the agent
"""
def add_tools(self, tools: List[BaseTool]) -> None:
"""Helper method to add tools and update agent state if needed"""
self.plugins.extend(tools)

View File

@ -8,6 +8,7 @@ from llama_index.readers.base import BaseReader
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.agents import BaseAgent
from kotaemon.base import BaseComponent
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices.extractors import BaseDocParser
@ -20,7 +21,6 @@ from kotaemon.loaders import (
OCRReader,
PandasExcelReader,
)
from kotaemon.pipelines.agents import BaseAgent
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.storages import (

View File

@ -5,16 +5,16 @@ from typing import List, Sequence
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.agents import BaseAgent
from kotaemon.agents.tools import ComponentTool
from kotaemon.base import BaseComponent
from kotaemon.base.schema import Document, RetrievedDocument
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices.rankings import BaseReranking
from kotaemon.llms import PromptTemplate
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents import BaseAgent
from kotaemon.pipelines.citation import CitationPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.pipelines.tools import ComponentTool
from kotaemon.storages import (
BaseDocumentStore,
BaseVectorStore,

View File

@ -3,17 +3,12 @@ from unittest.mock import patch
import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.agents.base import AgentType
from kotaemon.agents.langchain import LangchainAgent
from kotaemon.agents.react import ReactAgent
from kotaemon.agents.rewoo import RewooAgent
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents.base import AgentType
from kotaemon.pipelines.agents.langchain import LangchainAgent
from kotaemon.pipelines.agents.react import ReactAgent
from kotaemon.pipelines.agents.rewoo import RewooAgent
from kotaemon.pipelines.tools import (
BaseTool,
GoogleSearchTool,
LLMTool,
WikipediaTool,
)
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"

View File

@ -4,11 +4,11 @@ from pathlib import Path
import pytest
from openai.resources.embeddings import Embeddings
from kotaemon.agents.tools import ComponentTool, GoogleSearchTool, WikipediaTool
from kotaemon.base import Document
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.pipelines.tools import ComponentTool, GoogleSearchTool, WikipediaTool
from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f: