Feat/Add ReAct and ReWOO Reasoning Pipelines (#43)

* Add ReactAgentPipeline by wrapping the ReactAgent

* Implement stream processing for ReactAgentPipeline and RewooAgentPipeline

* Fix highlight_citation in Rewoo and remove highlight_citation from React

* Fix importing ktem.llms inside kotaemon

* fix: Change Rewoo::solver's output to LLMInterface instead of plain text

* Add more user_settings to the RewooAgentPipeline

* Fix LLMTool

* Add more user_settings to the ReactAgentPipeline

* Minor fix

* Stream the react agent immediately

* Yield the Rewoo progress to info panel

* Hide the agent in flowsettings

* Remove redundant comments

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
Author: Albert
Date: 2024-05-09 16:06:24 +07:00
Committed by: GitHub
Parent: ec11b54ff2
Commit: 466adf2d94
12 changed files with 1114 additions and 25 deletions

View File

@@ -253,5 +253,6 @@ class AgentOutput(LLMInterface):
text: str
type: str = "agent"
agent_type: AgentType
status: Literal["finished", "stopped", "failed"]
status: Literal["thinking", "finished", "stopped", "failed"]
error: Optional[str] = None
intermediate_steps: Optional[list] = None
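Note: "thinking" marks in-flight progress, while "finished", "stopped", and "failed" stay terminal. A minimal consumption sketch, assuming `agent` is any configured agent whose stream() yields AgentOutput as above:

for output in agent.stream("What is the tallest building in Hanoi?"):
    if output.status == "thinking":
        # in-flight progress: show it, but don't treat it as the answer
        print("progress:", output.intermediate_steps)
    else:  # "finished", "stopped", or "failed"
        print("final:", output.text, output.error)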

View File

@@ -1,11 +1,15 @@
import logging
import re
from functools import partial
from typing import Optional
import tiktoken
from kotaemon.agents.base import BaseAgent, BaseLLM
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
from kotaemon.agents.tools import BaseTool
-from kotaemon.base import Param
+from kotaemon.base import Document, Param
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import PromptTemplate
FINAL_ANSWER_ACTION = "Final Answer:"
@@ -22,6 +26,7 @@ class ReactAgent(BaseAgent):
description: str = "ReactAgent for answering multi-step reasoning questions"
llm: BaseLLM
prompt_template: Optional[PromptTemplate] = None
output_lang: str = "English"
plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="List of tools to be used in the agent. "
)
@@ -32,8 +37,18 @@ class ReactAgent(BaseAgent):
default_callback=lambda _: [],
help="List of AgentAction and observation (tool) output",
)
-max_iterations: int = 10
+max_iterations: int = 5
strict_decode: bool = False
trim_func: TokenSplitter = TokenSplitter.withx(
chunk_size=800,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
def _compose_plugin_description(self) -> str:
"""
@@ -119,6 +134,7 @@ class ReactAgent(BaseAgent):
agent_scratchpad=agent_scratchpad,
tool_description=tool_description,
tool_names=tool_names,
lang=self.output_lang,
)
def _format_function_map(self) -> dict[str, BaseTool]:
@@ -133,6 +149,20 @@ class ReactAgent(BaseAgent):
function_map[plugin.name] = plugin
return function_map
def _trim(self, text: str) -> str:
"""
Trim the text to the maximum token length.
"""
if isinstance(text, str):
texts = self.trim_func([Document(text=text)])
elif isinstance(text, Document):
texts = self.trim_func([text])
else:
raise ValueError("Invalid text type to trim")
trim_text = texts[0].text
logging.info(f"len (trimmed): {len(trim_text)}")
return trim_text
def clear(self):
"""
Clear and reset the agent.
@@ -183,6 +213,11 @@ class ReactAgent(BaseAgent):
logging.info(f"Action: {action_name}")
logging.info(f"Tool Input: {tool_input}")
result = self._format_function_map()[action_name](tool_input)
# trim the worker output to 800 tokens, as we are appending
# all workers' logs and they can exceed the token limit if we
# don't cap each one. Adjust this limit to the LLM's context window.
result = self._trim(result)
logging.info(f"Result: {result}")
self.intermediate_steps.append((action_step, result))
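For context, a rough sketch of what the trim_func configured above does, assuming TokenSplitter accepts these constructor arguments directly (the diff only shows the TokenSplitter.withx declaration):

from functools import partial

import tiktoken

from kotaemon.base import Document
from kotaemon.indices.splitters import TokenSplitter

splitter = TokenSplitter(
    chunk_size=800,
    chunk_overlap=0,
    separator=" ",
    tokenizer=partial(
        tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
        allowed_special=set(),
        disallowed_special="all",
    ),
)
long_observation = "word " * 5000
# _trim keeps only the first chunk, i.e. at most ~800 tokens of the observation
trimmed = splitter([Document(text=long_observation)])[0].text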
@@ -202,3 +237,100 @@ class ReactAgent(BaseAgent):
intermediate_steps=self.intermediate_steps,
max_iterations=max_iterations,
)
def stream(self, instruction, max_iterations=None):
"""
Stream the agent with the given instruction.
Args:
instruction: Instruction to run the agent with.
max_iterations: Maximum number of reasoning iterations;
defaults to the agent's max_iterations (5).
Yields:
AgentOutput objects; the final one has status "finished" or "stopped".
"""
if not max_iterations:
max_iterations = self.max_iterations
assert max_iterations > 0
self.clear()
logging.info(f"Running {self.name} with instruction: {instruction}")
print(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
status = "failed"
response_text = None
for step_count in range(1, max_iterations + 1):
prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}")
print(f"Prompt: {prompt}")
response = self.llm(
prompt, stop=["Observation:"]
) # TODO: could cause bugs if llm doesn't have `stop` as a parameter
response_text = response.text
logging.info(f"Response: {response_text}")
print(f"Response: {response_text}")
action_step = self._parse_output(response_text)
if action_step is None:
raise ValueError("Invalid action")
is_finished_chain = isinstance(action_step, AgentFinish)
if is_finished_chain:
result = response_text
if "Final Answer:" in response_text:
result = response_text.split("Final Answer:")[-1].strip()
else:
assert isinstance(action_step, AgentAction)
action_name = action_step.tool
tool_input = action_step.tool_input
logging.info(f"Action: {action_name}")
print(f"Action: {action_name}")
logging.info(f"Tool Input: {tool_input}")
print(f"Tool Input: {tool_input}")
result = self._format_function_map()[action_name](tool_input)
# trim the worker output to 800 tokens, as we are appending
# all workers' logs and they can exceed the token limit if we
# don't cap each one. Adjust this limit to the LLM's context window.
result = self._trim(result)
logging.info(f"Result: {result}")
print(f"Result: {result}")
self.intermediate_steps.append((action_step, result))
if is_finished_chain:
logging.info(f"Finished after {step_count} steps.")
status = "finished"
yield AgentOutput(
text=result,
agent_type=self.agent_type,
status=status,
intermediate_steps=self.intermediate_steps[-1],
)
break
else:
yield AgentOutput(
text="",
agent_type=self.agent_type,
status="thinking",
intermediate_steps=self.intermediate_steps[-1],
)
else:
status = "stopped"
yield AgentOutput(
text="",
agent_type=self.agent_type,
status=status,
intermediate_steps=self.intermediate_steps[-1],
)
return AgentOutput(
text=response_text,
agent_type=self.agent_type,
status=status,
total_tokens=total_token,
total_cost=total_cost,
intermediate_steps=self.intermediate_steps,
max_iterations=max_iterations,
)
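A minimal sketch of driving the new generator end to end; llm and search_tool stand in for a configured BaseLLM and BaseTool, and the import path is an assumption:

from kotaemon.agents import ReactAgent  # import path assumed

agent = ReactAgent(llm=llm, plugins=[search_tool], output_lang="English")
final_text = ""
for step in agent.stream("Summarize the latest release notes"):
    if step.status == "thinking":
        continue  # intermediate step; inspect step.intermediate_steps if needed
    final_text = step.text  # set on "finished", empty on "stopped"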

View File

@@ -3,7 +3,7 @@
from kotaemon.llms import PromptTemplate
zero_shot_react_prompt = PromptTemplate(
template="""Answer the following questions as best you can. You have access to the following tools:
template="""Answer the following questions as best you can. Give answer in {lang}. You have access to the following tools:
{tool_description}
Use the following format:
@@ -12,7 +12,7 @@ Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
-Action Input: the input to the action
+Action Input: the input to the action, should be different from the action input of the same action in previous steps.
Observation: the result of the action
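The new {lang} placeholder is filled by _compose_prompt via lang=self.output_lang (see the ReactAgent hunk above). A toy illustration with a simplified template; the real zero_shot_react_prompt has more placeholders:

from kotaemon.llms import PromptTemplate

toy_prompt = PromptTemplate(template="Give answer in {lang}. Question: {question}")
print(toy_prompt.populate(lang="English", question="Who wrote Dune?"))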

View File

@@ -1,14 +1,18 @@
import logging
import re
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
import tiktoken
from kotaemon.agents.base import BaseAgent
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
from kotaemon.agents.tools import BaseTool
from kotaemon.agents.utils import get_plugin_response_content
-from kotaemon.base import Node, Param
-from kotaemon.indices.qa import CitationPipeline
+from kotaemon.base import Document, Node, Param
+from kotaemon.indices.qa.citation import CitationPipeline
from kotaemon.indices.splitters import TokenSplitter
from kotaemon.llms import BaseLLM, PromptTemplate
from .planner import Planner
@@ -22,6 +26,7 @@ class RewooAgent(BaseAgent):
name: str = "RewooAgent"
agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions"
output_lang: str = "English"
planner_llm: BaseLLM
solver_llm: BaseLLM
prompt_template: dict[str, PromptTemplate] = Param(
@@ -34,6 +39,16 @@ class RewooAgent(BaseAgent):
examples: dict[str, str | list[str]] = Param(
default_callback=lambda _: {}, help="Examples to be used in the agent."
)
trim_func: TokenSplitter = TokenSplitter.withx(
chunk_size=3000,
chunk_overlap=0,
separator=" ",
tokenizer=partial(
tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
allowed_special=set(),
disallowed_special="all",
),
)
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
def planner(self):
@@ -50,6 +65,7 @@ class RewooAgent(BaseAgent):
model=self.solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
output_lang=self.output_lang,
)
def _parse_plan_map(
@@ -159,8 +175,13 @@ class RewooAgent(BaseAgent):
tool_input = tool_input[:-1]
# find variables in input and replace with previous evidences
for var in re.findall(r"#E\d+", tool_input):
print("Tool input: ", tool_input)
print("Var: ", var)
print("Worker evidences: ", worker_evidences)
if var in worker_evidences:
-tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
+tool_input = tool_input.replace(
+var, worker_evidences.get(var, "") or ""
+)
try:
selected_plugin = self._find_plugin(tool)
if selected_plugin is None:
@@ -216,7 +237,7 @@ class RewooAgent(BaseAgent):
resp = r.result()
plugin_cost += resp["plugin_cost"]
plugin_token += resp["plugin_token"]
worker_evidences[resp["e"]] = resp["evidence"]
worker_evidences[resp["e"]] = self._trim_evidence(resp["evidence"])
output.done()
return worker_evidences, plugin_cost, plugin_token
@@ -226,6 +247,13 @@ class RewooAgent(BaseAgent):
if p.name == name:
return p
def _trim_evidence(self, evidence: str):
if evidence:
texts = self.trim_func([Document(text=evidence)])
evidence = texts[0].text
logging.info(f"len (trimmed): {len(evidence)}")
return evidence
@BaseAgent.safeguard_run
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
"""
@@ -269,5 +297,69 @@ class RewooAgent(BaseAgent):
total_tokens=total_token,
total_cost=total_cost,
citation=citation,
metadata={"citation": citation},
metadata={"citation": citation, "worker_log": worker_log},
)
def stream(self, instruction: str, use_citation: bool = False):
"""
Stream the agent with a given instruction.
"""
logging.info(f"Streaming {self.name} with instruction: {instruction}")
total_cost = 0.0
total_token = 0
# Plan
planner_output = self.planner(instruction)
planner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(planner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences(
planner_text_output
)
print("Planner output:", planner_text_output)
# Work
worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
planner_evidences, evidence_level
)
worker_log = ""
for plan in plan_to_es:
worker_log += f"{plan}: {plans[plan]}\n"
current_progress = f"{plan}: {plans[plan]}\n"
for e in plan_to_es[plan]:
worker_log += f"{e}: {worker_evidences[e]}\n"
current_progress += f"{e}: {worker_evidences[e]}\n"
yield AgentOutput(
text="",
agent_type=self.agent_type,
status="thinking",
intermediate_steps=[{"worker_log": current_progress}],
)
# Solve
solver_response = ""
for solver_output in self.solver.stream(instruction, worker_log):
solver_output_text = solver_output.text
solver_response += solver_output_text
yield AgentOutput(
text=solver_output_text,
agent_type=self.agent_type,
status="thinking",
)
if use_citation:
citation_pipeline = CitationPipeline(llm=self.solver_llm)
citation = citation_pipeline.invoke(
context=worker_log, question=instruction
)
else:
citation = None
return AgentOutput(
text="",
agent_type=self.agent_type,
status="finished",
total_tokens=total_token,
total_cost=total_cost,
citation=citation,
metadata={"citation": citation, "worker_log": worker_log},
)
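A minimal sketch of consuming RewooAgent.stream; planner_llm, solver_llm, and tools are assumed, pre-configured stand-ins, and the import path is an assumption. Worker progress arrives with intermediate_steps set and empty text, while solver tokens arrive as text:

from kotaemon.agents import RewooAgent  # import path assumed

agent = RewooAgent(planner_llm=planner_llm, solver_llm=solver_llm, plugins=tools)
answer = ""
for out in agent.stream("Compare plan A and plan B", use_citation=True):
    if out.intermediate_steps:
        continue  # worker progress; see out.intermediate_steps[0]["worker_log"]
    answer += out.text  # solver tokens stream in with status="thinking"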

View File

@@ -81,3 +81,26 @@ class Planner(BaseComponent):
raise ValueError("Planner failed to retrieve response from LLM") from e
return response
def stream(self, instruction: str, output: BaseScratchPad = BaseScratchPad()):
output.info("Running Planner")
prompt = self._compose_prompt(instruction)
output.debug(f"Prompt: {prompt}")
response = ""
try:
for text in self.model.stream(prompt):
response += text.text  # accumulate chunk text, as Solver.stream does
yield text
self.log_progress(".planner", response=response)
output.info("Planner run successful.")
except NotImplementedError:
print("Streaming is not supported, falling back to normal run")
response = self.model(prompt)
yield response
except ValueError as e:
output.error("Planner failed to retrieve response from LLM")
raise ValueError("Planner failed to retrieve response from LLM") from e
return response

View File

@@ -81,7 +81,7 @@ And so on...
zero_shot_solver_prompt = PromptTemplate(
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.
##My Plans and Evidences##
{plan_evidence}
@@ -99,7 +99,7 @@ So, <your conclusion>.
few_shot_solver_prompt = PromptTemplate(
template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.
##My Plans and Evidences##
{plan_evidence}

View File

@@ -11,6 +11,7 @@ class Solver(BaseComponent):
model: BaseLLM
prompt_template: Optional[PromptTemplate] = None
examples: Optional[Union[str, List[str]]] = None
output_lang: str = "English"
def _compose_fewshot_prompt(self) -> str:
if self.examples is None:
@@ -20,7 +21,7 @@ class Solver(BaseComponent):
else:
return "\n\n".join([e.strip("\n") for e in self.examples])
-def _compose_prompt(self, instruction, plan_evidence) -> str:
+def _compose_prompt(self, instruction, plan_evidence, output_lang) -> str:
"""
Compose the prompt from template, plan&evidence, examples and instruction.
"""
@@ -28,20 +29,28 @@ class Solver(BaseComponent):
if self.prompt_template is not None:
if "fewshot" in self.prompt_template.placeholders:
return self.prompt_template.populate(
-plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+plan_evidence=plan_evidence,
+fewshot=fewshot,
+task=instruction,
+lang=output_lang,
)
else:
return self.prompt_template.populate(
-plan_evidence=plan_evidence, task=instruction
+plan_evidence=plan_evidence, task=instruction, lang=output_lang
)
else:
if self.examples is not None:
return few_shot_solver_prompt.populate(
-plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+plan_evidence=plan_evidence,
+fewshot=fewshot,
+task=instruction,
+lang=output_lang,
)
else:
return zero_shot_solver_prompt.populate(
-plan_evidence=plan_evidence, task=instruction
+plan_evidence=plan_evidence,
+task=instruction,
+lang=output_lang,
)
def run(
@@ -54,7 +63,7 @@ class Solver(BaseComponent):
output.info("Running Solver")
output.debug(f"Instruction: {instruction}")
output.debug(f"Plan Evidence: {plan_evidence}")
-prompt = self._compose_prompt(instruction, plan_evidence)
+prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
output.debug(f"Prompt: {prompt}")
try:
response = self.model(prompt)
@@ -63,3 +72,28 @@ class Solver(BaseComponent):
output.error("Solver failed to retrieve response from LLM")
return response
def stream(
self,
instruction: str,
plan_evidence: str,
output: BaseScratchPad = BaseScratchPad(),
) -> Any:
response = ""
output.info("Running Solver")
output.debug(f"Instruction: {instruction}")
output.debug(f"Plan Evidence: {plan_evidence}")
prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
output.debug(f"Prompt: {prompt}")
try:
for text in self.model.stream(prompt):
response += text.text
yield text
output.info("Planner run successful.")
except NotImplementedError:
# streaming unsupported: fall back to a single call and yield it whole
solver_output = self.model(prompt)
response = solver_output.text
yield solver_output
output.info("Solver run successful.")
except ValueError:
output.error("Solver failed to retrieve response from LLM")
return response
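Each yielded chunk exposes its token via .text (RewooAgent.stream reads solver_output.text above). A short consumption sketch, with solver and worker_log assumed to be a configured Solver and a plan/evidence string:

chunks = []
for chunk in solver.stream("the original question", worker_log):
    chunks.append(chunk.text)
full_answer = "".join(chunks)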

View File

@@ -1,4 +1,5 @@
from typing import AnyStr, Optional, Type
from urllib.error import HTTPError
from langchain.utilities import SerpAPIWrapper
from pydantic import BaseModel, Field
@@ -26,12 +27,17 @@ class GoogleSearchTool(BaseTool):
"install googlesearch using `pip3 install googlesearch-python` to "
"use this tool"
)
output = ""
search_results = search(query, advanced=True)
if search_results:
output = "\n".join(
"{} {}".format(item.title, item.description) for item in search_results
)
try:
output = ""
search_results = search(query, advanced=True)
if search_results:
output = "\n".join(
"{} {}".format(item.title, item.description)
for item in search_results
)
except HTTPError:
output = "No evidence found."
return output
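With the HTTPError guard, a failed lookup now degrades to a fixed string instead of raising. A usage sketch, assuming the tool is constructed with defaults and invoked like the other BaseTool plugins in this diff:

tool = GoogleSearchTool()
result = tool("site reliability engineering definition")
print(result)  # "No evidence found." when the search request errors out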

View File

@@ -2,9 +2,10 @@ from typing import AnyStr, Optional, Type
from pydantic import BaseModel, Field
from kotaemon.agents.tools.base import ToolException
from kotaemon.llms import BaseLLM
-from .base import BaseTool, ToolException
+from .base import BaseTool
class LLMArgs(BaseModel):