[AUR-431] Add ReAct Agent (#34)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

* add LLMTool

* update BaseAgent type

* add ReAct agent

* add ReAct agent

* minor update

* minor update

* minor update

* minor update

* update docstring

* fix max_iteration

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2023-10-02 11:29:12 +07:00 committed by GitHub
parent 91048770fa
commit 3cceec63ef
9 changed files with 340 additions and 27 deletions

View File

@ -34,16 +34,16 @@ class LangchainChatLLM(ChatLLM):
def agent(self) -> BaseLanguageModel: def agent(self) -> BaseLanguageModel:
return self._lc_class(**self._kwargs) return self._lc_class(**self._kwargs)
def run_raw(self, text: str) -> LLMInterface: def run_raw(self, text: str, **kwargs) -> LLMInterface:
message = HumanMessage(content=text) message = HumanMessage(content=text)
return self.run_document([message]) return self.run_document([message], **kwargs)
def run_batch_raw(self, text: List[str]) -> List[LLMInterface]: def run_batch_raw(self, text: List[str], **kwargs) -> List[LLMInterface]:
inputs = [[HumanMessage(content=each)] for each in text] inputs = [[HumanMessage(content=each)] for each in text]
return self.run_batch_document(inputs) return self.run_batch_document(inputs, **kwargs)
def run_document(self, text: List[Message]) -> LLMInterface: def run_document(self, text: List[Message], **kwargs) -> LLMInterface:
pred = self.agent.generate([text]) # type: ignore pred = self.agent.generate([text], **kwargs) # type: ignore
return LLMInterface( return LLMInterface(
text=[each.text for each in pred.generations[0]], text=[each.text for each in pred.generations[0]],
completion_tokens=pred.llm_output["token_usage"]["completion_tokens"], completion_tokens=pred.llm_output["token_usage"]["completion_tokens"],
@ -52,20 +52,22 @@ class LangchainChatLLM(ChatLLM):
logits=[], logits=[],
) )
def run_batch_document(self, text: List[List[Message]]) -> List[LLMInterface]: def run_batch_document(
self, text: List[List[Message]], **kwargs
) -> List[LLMInterface]:
outputs = [] outputs = []
for each_text in text: for each_text in text:
outputs.append(self.run_document(each_text)) outputs.append(self.run_document(each_text, **kwargs))
return outputs return outputs
def is_document(self, text) -> bool: def is_document(self, text, **kwargs) -> bool:
if isinstance(text, str): if isinstance(text, str):
return False return False
elif isinstance(text, List) and isinstance(text[0], str): elif isinstance(text, List) and isinstance(text[0], str):
return False return False
return True return True
def is_batch(self, text) -> bool: def is_batch(self, text, **kwargs) -> bool:
if isinstance(text, str): if isinstance(text, str):
return False return False
elif isinstance(text, List): elif isinstance(text, List):

View File

@ -1,5 +1,5 @@
from enum import Enum from enum import Enum
from typing import Dict, List, Union from typing import Dict, List, Optional, Union
from pydantic import BaseModel from pydantic import BaseModel
@ -50,7 +50,7 @@ class AgentOutput(BaseModel):
class BaseAgent(BaseTool): class BaseAgent(BaseTool):
name: str name: str
"""Name of the agent.""" """Name of the agent."""
type: AgentType agent_type: AgentType
"""Agent type, must be one of AgentType""" """Agent type, must be one of AgentType"""
description: str description: str
"""Description used to tell the model how/when/why to use the agent. """Description used to tell the model how/when/why to use the agent.
@ -59,7 +59,7 @@ class BaseAgent(BaseTool):
llm: Union[BaseLLM, Dict[str, BaseLLM]] llm: Union[BaseLLM, Dict[str, BaseLLM]]
"""Specify LLM to be used in the model, cam be a dict to supply different """Specify LLM to be used in the model, cam be a dict to supply different
LLMs to multiple purposes in the agent""" LLMs to multiple purposes in the agent"""
prompt_template: Union[PromptTemplate, Dict[str, PromptTemplate]] prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]]
"""A prompt template or a dict to supply different prompt to the agent """A prompt template or a dict to supply different prompt to the agent
""" """
plugins: List[BaseTool] plugins: List[BaseTool]

View File

@ -0,0 +1,3 @@
from .agent import ReactAgent
__all__ = ["ReactAgent"]

View File

@ -0,0 +1,188 @@
import logging
import re
from typing import Dict, List, Optional, Tuple, Type, Union
from pydantic import BaseModel, create_model
from kotaemon.prompt.template import PromptTemplate
from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import AgentAction, AgentFinish
FINAL_ANSWER_ACTION = "Final Answer:"
class ReactAgent(BaseAgent):
"""
Sequential ReactAgent class inherited from BaseAgent.
Implementing ReAct agent paradigm https://arxiv.org/pdf/2210.03629.pdf
"""
name: str = "ReactAgent"
agent_type: AgentType = AgentType.react
description: str = "ReactAgent for answering multi-step reasoning questions"
llm: Union[BaseLLM, Dict[str, BaseLLM]]
prompt_template: Optional[PromptTemplate] = None
plugins: List[BaseTool] = list()
examples: Dict[str, Union[str, List[str]]] = dict()
args_schema: Optional[Type[BaseModel]] = create_model(
"ReactArgsSchema", instruction=(str, ...)
)
intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
"""List of AgentAction and observation (tool) output"""
max_iterations = 10
strict_decode: bool = False
def _compose_plugin_description(self) -> str:
"""
Compose the worker prompt from the workers.
Example:
toolname1[input]: tool1 description
toolname2[input]: tool2 description
"""
prompt = ""
try:
for plugin in self.plugins:
prompt += f"{plugin.name}[input]: {plugin.description}\n"
except Exception:
raise ValueError("Worker must have a name and description.")
return prompt
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[Union[AgentAction, AgentFinish], str]] = []
) -> str:
"""Construct the scratchpad that lets the agent continue its thought process."""
thoughts = ""
for action, observation in intermediate_steps:
thoughts += action.log
thoughts += f"\nObservation: {observation}\nThought:"
return thoughts
def _parse_output(self, text: str) -> Optional[Union[AgentAction, AgentFinish]]:
"""
Parse text output from LLM for the next Action or Final Answer
Using Regex to parse "Action:\n Action Input:\n" for the next Action
Using FINAL_ANSWER_ACTION to parse Final Answer
Args:
text[str]: input text to parse
"""
includes_answer = FINAL_ANSWER_ACTION in text
regex = (
r"Action\s*\d*\s*:[\s]*(.*?)[\s]*Action\s*\d*\s*Input\s*\d*\s*:[\s]*(.*)"
)
action_match = re.search(regex, text, re.DOTALL)
action_output: Optional[Union[AgentAction, AgentFinish]] = None
if action_match:
if includes_answer:
raise Exception(
"Parsing LLM output produced both a final answer "
f"and a parse-able action: {text}"
)
action = action_match.group(1).strip()
action_input = action_match.group(2)
tool_input = action_input.strip(" ")
# ensure if its a well formed SQL query we don't remove any trailing " chars
if tool_input.startswith("SELECT ") is False:
tool_input = tool_input.strip('"')
action_output = AgentAction(action, tool_input, text)
elif includes_answer:
action_output = AgentFinish(
{"output": text.split(FINAL_ANSWER_ACTION)[-1].strip()}, text
)
else:
if self.strict_decode:
raise Exception(f"Could not parse LLM output: `{text}`")
else:
action_output = AgentFinish({"output": text}, text)
return action_output
def _compose_prompt(self, instruction) -> str:
"""
Compose the prompt from template, worker description, examples and instruction.
"""
agent_scratchpad = self._construct_scratchpad(self.intermediate_steps)
tool_description = self._compose_plugin_description()
tool_names = ", ".join([plugin.name for plugin in self.plugins])
if self.prompt_template is None:
from .prompt import zero_shot_react_prompt
self.prompt_template = zero_shot_react_prompt
return self.prompt_template.populate(
instruction=instruction,
agent_scratchpad=agent_scratchpad,
tool_description=tool_description,
tool_names=tool_names,
)
def _format_function_map(self) -> Dict[str, BaseTool]:
"""Format the function map for the open AI function API.
Return:
Dict[str, Callable]: The function map.
"""
# Map the function name to the real function object.
function_map = {}
for plugin in self.plugins:
function_map[plugin.name] = plugin
return function_map
def clear(self):
"""
Clear and reset the agent.
"""
self.intermediate_steps = []
def run(self, instruction, max_iterations=None):
"""
Run the agent with the given instruction.
Args:
instruction: Instruction to run the agent with.
max_iterations: Maximum number of iterations
of reasoning steps, defaults to 10.
Return:
AgentOutput object.
"""
if not max_iterations:
max_iterations = self.max_iterations
assert max_iterations > 0
self.clear()
logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
for _ in range(max_iterations):
prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}")
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
response_text = response.text[0]
logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text)
if action_step is None:
raise ValueError("Invalid action")
is_finished_chain = isinstance(action_step, AgentFinish)
if is_finished_chain:
result = ""
else:
assert isinstance(action_step, AgentAction)
action_name = action_step.tool
tool_input = action_step.tool_input
logging.info(f"Action: {action_name}")
logging.info(f"Tool Input: {tool_input}")
result = self._format_function_map()[action_name](tool_input)
logging.info(f"Result: {result}")
self.intermediate_steps.append((action_step, result))
if is_finished_chain:
break
return AgentOutput(
output=response_text, cost=total_cost, token_usage=total_token
)

View File

@ -0,0 +1,28 @@
# flake8: noqa
from kotaemon.prompt.template import PromptTemplate
zero_shot_react_prompt = PromptTemplate(
template="""Answer the following questions as best you can. You have access to the following tools:
{tool_description}.
Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: the input to the action
Observation: the result of the action
... (this Thought/Action/Action Input/Observation can repeat N times)
#Thought: I now know the final answer
Final Answer: the final answer to the original input question
Begin! After each Action Input.
Question: {instruction}
Thought:{agent_scratchpad}
"""
)

View File

@ -21,7 +21,7 @@ class RewooAgent(BaseAgent):
Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf""" Implementing ReWOO paradigm https://arxiv.org/pdf/2305.18323.pdf"""
name: str = "RewooAgent" name: str = "RewooAgent"
type: AgentType = AgentType.rewoo agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions" description: str = "RewooAgent for answering multi-step reasoning questions"
llm: Union[BaseLLM, Dict[str, BaseLLM]] # {"Planner": xxx, "Solver": xxx} llm: Union[BaseLLM, Dict[str, BaseLLM]] # {"Planner": xxx, "Solver": xxx}
prompt_template: Dict[ prompt_template: Dict[

View File

@ -1,5 +1,6 @@
from .base import BaseTool, ComponentTool from .base import BaseTool, ComponentTool
from .google import GoogleSearchTool from .google import GoogleSearchTool
from .llm import LLMTool
from .wikipedia import WikipediaTool from .wikipedia import WikipediaTool
__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool"] __all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool", "LLMTool"]

View File

@ -0,0 +1,36 @@
from typing import AnyStr, Optional, Type, Union
from pydantic import BaseModel, Field
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.llms.completions.base import LLM
from .base import BaseTool, ToolException
BaseLLM = Union[ChatLLM, LLM]
class LLMArgs(BaseModel):
query: str = Field(..., description="a search question or prompt")
class LLMTool(BaseTool):
name = "llm"
description = (
"A pretrained LLM like yourself. Useful when you need to act with "
"general world knowledge and common sense. Prioritize it when you "
"are confident in solving the problem "
"yourself. Input can be any instruction."
)
llm: BaseLLM = AzureChatOpenAI()
args_schema: Optional[Type[BaseModel]] = LLMArgs
def _run_tool(self, query: AnyStr) -> str:
output = None
try:
response = self.llm(query)
except ValueError:
raise ToolException("LLM Tool call failed")
output = response.text[0]
return output

View File

@ -1,11 +1,12 @@
from unittest.mock import patch from unittest.mock import patch
from kotaemon.llms.chats.openai import AzureChatOpenAI from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.pipelines.agents.react import ReactAgent
from kotaemon.pipelines.agents.rewoo import RewooAgent from kotaemon.pipelines.agents.rewoo import RewooAgent
from kotaemon.pipelines.tools import GoogleSearchTool, WikipediaTool from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool
FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!" FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
_openai_chat_completion_responses = [ _openai_chat_completion_responses_rewoo = [
{ {
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion", "object": "chat.completion",
@ -17,15 +18,24 @@ _openai_chat_completion_responses = [
"finish_reason": "stop", "finish_reason": "stop",
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": "#Plan1: Search for Cinnamon AI company on Google\n" "content": text,
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]",
}, },
} }
], ],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}, }
for text in [
(
"#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
),
FINAL_RESPONSE_TEXT,
]
]
_openai_chat_completion_responses_react = [
{ {
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion", "object": "chat.completion",
@ -37,18 +47,36 @@ _openai_chat_completion_responses = [
"finish_reason": "stop", "finish_reason": "stop",
"message": { "message": {
"role": "assistant", "role": "assistant",
"content": FINAL_RESPONSE_TEXT, "content": text,
}, },
} }
], ],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}, }
for text in [
(
"I don't have prior knowledge about Cinnamon AI company, "
"so I should gather information about it.\n"
"Action: wikipedia\n"
"Action Input: Cinnamon AI company\n"
),
(
"The information retrieved from Wikipedia is not "
"about Cinnamon AI company, but about Blue Prism, "
"a British multinational software corporation. "
"I need to try another source to gather information "
"about Cinnamon AI company.\n"
"Action: google_search\n"
"Action Input: Cinnamon AI company\n"
),
FINAL_RESPONSE_TEXT,
]
] ]
@patch( @patch(
"openai.api_resources.chat_completion.ChatCompletion.create", "openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses, side_effect=_openai_chat_completion_responses_rewoo,
) )
def test_rewoo_agent(openai_completion): def test_rewoo_agent(openai_completion):
llm = AzureChatOpenAI( llm = AzureChatOpenAI(
@ -58,11 +86,38 @@ def test_rewoo_agent(openai_completion):
deployment_name="dummy-q2", deployment_name="dummy-q2",
temperature=0, temperature=0,
) )
plugins = [
plugins = [GoogleSearchTool(), WikipediaTool()] GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = RewooAgent(llm=llm, plugins=plugins) agent = RewooAgent(llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company") response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called() openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT assert response.output == FINAL_RESPONSE_TEXT
@patch(
"openai.api_resources.chat_completion.ChatCompletion.create",
side_effect=_openai_chat_completion_responses_react,
)
def test_react_agent(openai_completion):
llm = AzureChatOpenAI(
openai_api_base="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2",
temperature=0,
)
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = ReactAgent(llm=llm, plugins=plugins, max_iterations=4)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert response.output == FINAL_RESPONSE_TEXT