Feat/Add ReAct and ReWOO Reasoning Pipelines (#43)
* Add ReactAgentPipeline by wrapping the ReactAgent
* Implement stream processing for ReactAgentPipeline and RewooAgentPipeline
* Fix highlight_citation in Rewoo and remove highlight_citation from React
* Fix importing ktem.llms inside kotaemon
* fix: Change Rewoo::solver's output to LLMInterface instead of plain text
* Add more user_settings to the RewooAgentPipeline
* Fix LLMTool
* Add more user_settings to the ReactAgentPipeline
* Minor fix
* Stream the react agent immediately
* Yield the Rewoo progress to info panel
* Hide the agent in flowsettings
* Remove redundant comments

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
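Both agents now expose `stream()` as a generator that yields intermediate `AgentOutput` objects with `status="thinking"` before the final one. A minimal consumption sketch (illustrative only; `print_stream` and its wiring are not part of this commit):

```python
from kotaemon.agents.base import BaseAgent  # base class used by both agents in this diff


def print_stream(agent: BaseAgent, question: str) -> None:
    """Consume the streaming API added here: progress first, then the answer."""
    for step in agent.stream(question):
        if step.status == "thinking":
            print("progress:", step.intermediate_steps)  # e.g. routed to the info panel
        elif step.status == "finished":
            print("answer:", step.text)
```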
@@ -253,5 +253,6 @@ class AgentOutput(LLMInterface):
     text: str
     type: str = "agent"
     agent_type: AgentType
-    status: Literal["finished", "stopped", "failed"]
+    status: Literal["thinking", "finished", "stopped", "failed"]
     error: Optional[str] = None
+    intermediate_steps: Optional[list] = None
@@ -1,11 +1,15 @@
 import logging
 import re
+from functools import partial
 from typing import Optional

+import tiktoken
+
 from kotaemon.agents.base import BaseAgent, BaseLLM
 from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
 from kotaemon.agents.tools import BaseTool
-from kotaemon.base import Param
+from kotaemon.base import Document, Param
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import PromptTemplate

 FINAL_ANSWER_ACTION = "Final Answer:"
@@ -22,6 +26,7 @@ class ReactAgent(BaseAgent):
     description: str = "ReactAgent for answering multi-step reasoning questions"
     llm: BaseLLM
     prompt_template: Optional[PromptTemplate] = None
+    output_lang: str = "English"
     plugins: list[BaseTool] = Param(
         default_callback=lambda _: [], help="List of tools to be used in the agent. "
     )
@@ -32,8 +37,18 @@ class ReactAgent(BaseAgent):
         default_callback=lambda _: [],
         help="List of AgentAction and observation (tool) output",
     )
-    max_iterations: int = 10
+    max_iterations: int = 5
     strict_decode: bool = False
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=800,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )

     def _compose_plugin_description(self) -> str:
         """
@@ -119,6 +134,7 @@ class ReactAgent(BaseAgent):
             agent_scratchpad=agent_scratchpad,
             tool_description=tool_description,
             tool_names=tool_names,
+            lang=self.output_lang,
         )

     def _format_function_map(self) -> dict[str, BaseTool]:
@@ -133,6 +149,20 @@ class ReactAgent(BaseAgent):
             function_map[plugin.name] = plugin
         return function_map

+    def _trim(self, text: str) -> str:
+        """
+        Trim the text to the maximum token length.
+        """
+        if isinstance(text, str):
+            texts = self.trim_func([Document(text=text)])
+        elif isinstance(text, Document):
+            texts = self.trim_func([text])
+        else:
+            raise ValueError("Invalid text type to trim")
+        trim_text = texts[0].text
+        logging.info(f"len (trimmed): {len(trim_text)}")
+        return trim_text
+
     def clear(self):
         """
         Clear and reset the agent.
@@ -183,6 +213,11 @@ class ReactAgent(BaseAgent):
                 logging.info(f"Action: {action_name}")
                 logging.info(f"Tool Input: {tool_input}")
                 result = self._format_function_map()[action_name](tool_input)
+
+                # trim the worker output to 1000 tokens, as we are appending
+                # all workers' logs and it can exceed the token limit if we
+                # don't limit each. Fix this number regarding to the LLM capacity.
+                result = self._trim(result)
                 logging.info(f"Result: {result}")

             self.intermediate_steps.append((action_step, result))
@@ -202,3 +237,100 @@ class ReactAgent(BaseAgent):
             intermediate_steps=self.intermediate_steps,
             max_iterations=max_iterations,
         )
+
+    def stream(self, instruction, max_iterations=None):
+        """
+        Stream the agent with the given instruction.
+
+        Args:
+            instruction: Instruction to run the agent with.
+            max_iterations: Maximum number of iterations
+                of reasoning steps, defaults to 10.
+
+        Return:
+            AgentOutput object.
+        """
+        if not max_iterations:
+            max_iterations = self.max_iterations
+        assert max_iterations > 0
+
+        self.clear()
+        logging.info(f"Running {self.name} with instruction: {instruction}")
+        print(f"Running {self.name} with instruction: {instruction}")
+        total_cost = 0.0
+        total_token = 0
+        status = "failed"
+        response_text = None
+
+        for step_count in range(1, max_iterations + 1):
+            prompt = self._compose_prompt(instruction)
+            logging.info(f"Prompt: {prompt}")
+            print(f"Prompt: {prompt}")
+            response = self.llm(
+                prompt, stop=["Observation:"]
+            )  # TODO: could cause bugs if llm doesn't have `stop` as a parameter
+            response_text = response.text
+            logging.info(f"Response: {response_text}")
+            print(f"Response: {response_text}")
+            action_step = self._parse_output(response_text)
+            if action_step is None:
+                raise ValueError("Invalid action")
+            is_finished_chain = isinstance(action_step, AgentFinish)
+            if is_finished_chain:
+                result = response_text
+                if "Final Answer:" in response_text:
+                    result = response_text.split("Final Answer:")[-1].strip()
+            else:
+                assert isinstance(action_step, AgentAction)
+                action_name = action_step.tool
+                tool_input = action_step.tool_input
+                logging.info(f"Action: {action_name}")
+                print(f"Action: {action_name}")
+                logging.info(f"Tool Input: {tool_input}")
+                print(f"Tool Input: {tool_input}")
+                result = self._format_function_map()[action_name](tool_input)
+
+                # trim the worker output to 1000 tokens, as we are appending
+                # all workers' logs and it can exceed the token limit if we
+                # don't limit each. Fix this number regarding to the LLM capacity.
+                result = self._trim(result)
+                logging.info(f"Result: {result}")
+                print(f"Result: {result}")
+
+            self.intermediate_steps.append((action_step, result))
+            if is_finished_chain:
+                logging.info(f"Finished after {step_count} steps.")
+                status = "finished"
+                yield AgentOutput(
+                    text=result,
+                    agent_type=self.agent_type,
+                    status=status,
+                    intermediate_steps=self.intermediate_steps[-1],
+                )
+                break
+            else:
+                yield AgentOutput(
+                    text="",
+                    agent_type=self.agent_type,
+                    status="thinking",
+                    intermediate_steps=self.intermediate_steps[-1],
+                )
+
+        else:
+            status = "stopped"
+            yield AgentOutput(
+                text="",
+                agent_type=self.agent_type,
+                status=status,
+                intermediate_steps=self.intermediate_steps[-1],
+            )
+
+        return AgentOutput(
+            text=response_text,
+            agent_type=self.agent_type,
+            status=status,
+            total_tokens=total_token,
+            total_cost=total_cost,
+            intermediate_steps=self.intermediate_steps,
+            max_iterations=max_iterations,
+        )
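Because `stream()` is a generator, the trailing `return AgentOutput(...)` above is never yielded; it only becomes the value attached to `StopIteration`. A hypothetical caller that also wants that final summary object could drive the generator manually, for example:

```python
def drive_react_stream(agent, instruction):
    """Illustrative sketch (not part of the commit): collect yielded steps and the return value."""
    gen = agent.stream(instruction)
    steps, final = [], None
    while True:
        try:
            steps.append(next(gen))  # AgentOutput with status "thinking"/"finished"/"stopped"
        except StopIteration as stop:
            final = stop.value       # the AgentOutput built after the loop ends
            break
    return steps, final
```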
@@ -3,7 +3,7 @@
 from kotaemon.llms import PromptTemplate

 zero_shot_react_prompt = PromptTemplate(
-    template="""Answer the following questions as best you can. You have access to the following tools:
+    template="""Answer the following questions as best you can. Give answer in {lang}. You have access to the following tools:
 {tool_description}
 Use the following format:

@@ -12,7 +12,7 @@ Thought: you should always think about what to do

 Action: the action to take, should be one of [{tool_names}]

-Action Input: the input to the action
+Action Input: the input to the action, should be different from the action input of the same action in previous steps.

 Observation: the result of the action

@@ -1,14 +1,18 @@
 import logging
 import re
 from concurrent.futures import ThreadPoolExecutor
+from functools import partial
 from typing import Any

+import tiktoken
+
 from kotaemon.agents.base import BaseAgent
 from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
 from kotaemon.agents.tools import BaseTool
 from kotaemon.agents.utils import get_plugin_response_content
-from kotaemon.base import Node, Param
-from kotaemon.indices.qa import CitationPipeline
+from kotaemon.base import Document, Node, Param
+from kotaemon.indices.qa.citation import CitationPipeline
+from kotaemon.indices.splitters import TokenSplitter
 from kotaemon.llms import BaseLLM, PromptTemplate

 from .planner import Planner
@@ -22,6 +26,7 @@ class RewooAgent(BaseAgent):
     name: str = "RewooAgent"
     agent_type: AgentType = AgentType.rewoo
     description: str = "RewooAgent for answering multi-step reasoning questions"
+    output_lang: str = "English"
     planner_llm: BaseLLM
     solver_llm: BaseLLM
     prompt_template: dict[str, PromptTemplate] = Param(
@@ -34,6 +39,16 @@ class RewooAgent(BaseAgent):
     examples: dict[str, str | list[str]] = Param(
         default_callback=lambda _: {}, help="Examples to be used in the agent."
     )
+    trim_func: TokenSplitter = TokenSplitter.withx(
+        chunk_size=3000,
+        chunk_overlap=0,
+        separator=" ",
+        tokenizer=partial(
+            tiktoken.encoding_for_model("gpt-3.5-turbo").encode,
+            allowed_special=set(),
+            disallowed_special="all",
+        ),
+    )

     @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
     def planner(self):
@@ -50,6 +65,7 @@ class RewooAgent(BaseAgent):
             model=self.solver_llm,
             prompt_template=self.prompt_template.get("Solver", None),
             examples=self.examples.get("Solver", None),
+            output_lang=self.output_lang,
         )

     def _parse_plan_map(
@@ -159,8 +175,13 @@ class RewooAgent(BaseAgent):
             tool_input = tool_input[:-1]
         # find variables in input and replace with previous evidences
         for var in re.findall(r"#E\d+", tool_input):
+            print("Tool input: ", tool_input)
+            print("Var: ", var)
+            print("Worker evidences: ", worker_evidences)
             if var in worker_evidences:
-                tool_input = tool_input.replace(var, worker_evidences.get(var, ""))
+                tool_input = tool_input.replace(
+                    var, worker_evidences.get(var, "") or ""
+                )
         try:
             selected_plugin = self._find_plugin(tool)
             if selected_plugin is None:
@@ -216,7 +237,7 @@ class RewooAgent(BaseAgent):
                 resp = r.result()
                 plugin_cost += resp["plugin_cost"]
                 plugin_token += resp["plugin_token"]
-                worker_evidences[resp["e"]] = resp["evidence"]
+                worker_evidences[resp["e"]] = self._trim_evidence(resp["evidence"])
             output.done()

         return worker_evidences, plugin_cost, plugin_token
@@ -226,6 +247,13 @@ class RewooAgent(BaseAgent):
             if p.name == name:
                 return p

+    def _trim_evidence(self, evidence: str):
+        if evidence:
+            texts = self.trim_func([Document(text=evidence)])
+            evidence = texts[0].text
+            logging.info(f"len (trimmed): {len(evidence)}")
+        return evidence
+
     @BaseAgent.safeguard_run
     def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
         """
@@ -269,5 +297,69 @@ class RewooAgent(BaseAgent):
             total_tokens=total_token,
             total_cost=total_cost,
             citation=citation,
-            metadata={"citation": citation},
+            metadata={"citation": citation, "worker_log": worker_log},
         )
+
+    def stream(self, instruction: str, use_citation: bool = False):
+        """
+        Stream the agent with a given instruction.
+        """
+        logging.info(f"Streaming {self.name} with instruction: {instruction}")
+        total_cost = 0.0
+        total_token = 0
+
+        # Plan
+        planner_output = self.planner(instruction)
+        planner_text_output = planner_output.text
+        plan_to_es, plans = self._parse_plan_map(planner_text_output)
+        planner_evidences, evidence_level = self._parse_planner_evidences(
+            planner_text_output
+        )
+
+        print("Planner output:", planner_text_output)
+        # Work
+        worker_evidences, plugin_cost, plugin_token = self._get_worker_evidence(
+            planner_evidences, evidence_level
+        )
+        worker_log = ""
+        for plan in plan_to_es:
+            worker_log += f"{plan}: {plans[plan]}\n"
+            current_progress = f"{plan}: {plans[plan]}\n"
+            for e in plan_to_es[plan]:
+                worker_log += f"{e}: {worker_evidences[e]}\n"
+                current_progress += f"{e}: {worker_evidences[e]}\n"

+            yield AgentOutput(
+                text="",
+                agent_type=self.agent_type,
+                status="thinking",
+                intermediate_steps=[{"worker_log": current_progress}],
+            )
+
+        # Solve
+        solver_response = ""
+        for solver_output in self.solver.stream(instruction, worker_log):
+            solver_output_text = solver_output.text
+            solver_response += solver_output_text
+            yield AgentOutput(
+                text=solver_output_text,
+                agent_type=self.agent_type,
+                status="thinking",
+            )
+        if use_citation:
+            citation_pipeline = CitationPipeline(llm=self.solver_llm)
+            citation = citation_pipeline.invoke(
+                context=worker_log, question=instruction
+            )
+        else:
+            citation = None
+
+        return AgentOutput(
+            text="",
+            agent_type=self.agent_type,
+            status="finished",
+            total_tokens=total_token,
+            total_cost=total_cost,
+            citation=citation,
+            metadata={"citation": citation, "worker_log": worker_log},
+        )
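Here both the per-plan worker progress (carrying `intermediate_steps=[{"worker_log": ...}]`) and the token-level solver output are yielded with `status="thinking"`; the `"finished"` output is only the generator's return value. A hedged sketch of a consumer that separates the two and rebuilds the answer (hypothetical caller code, not part of this commit):

```python
def collect_rewoo_stream(agent, question: str) -> str:
    """Split streamed RewooAgent output into progress logs and answer tokens."""
    answer_parts = []
    for step in agent.stream(question, use_citation=False):
        if step.intermediate_steps:
            # planner/worker progress, suitable for an info panel
            print(step.intermediate_steps[0].get("worker_log", ""), end="")
        else:
            # incremental solver text
            answer_parts.append(step.text)
    return "".join(answer_parts)
```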
@@ -81,3 +81,26 @@ class Planner(BaseComponent):
             raise ValueError("Planner failed to retrieve response from LLM") from e

         return response
+
+    def stream(self, instruction: str, output: BaseScratchPad = BaseScratchPad()):
+        response = None
+        output.info("Running Planner")
+        prompt = self._compose_prompt(instruction)
+        output.debug(f"Prompt: {prompt}")
+
+        response = ""
+        try:
+            for text in self.model.stream(prompt):
+                response += text
+                yield text
+            self.log_progress(".planner", response=response)
+            output.info("Planner run successful.")
+        except NotImplementedError:
+            print("Streaming is not supported, falling back to normal run")
+            response = self.model(prompt)
+            yield response
+        except ValueError as e:
+            output.error("Planner failed to retrieve response from LLM")
+            raise ValueError("Planner failed to retrieve response from LLM") from e
+
+        return response
@@ -81,7 +81,7 @@ And so on...

 zero_shot_solver_prompt = PromptTemplate(
     template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans(#Plan) and evidences(#E) that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.

 ##My Plans and Evidences##
 {plan_evidence}
@@ -99,7 +99,7 @@ So, <your conclusion>.

 few_shot_solver_prompt = PromptTemplate(
     template="""You are an AI agent who solves a problem with my assistance. I will provide step-by-step plans and evidences that could be helpful.
-Your task is to briefly summarize each step, then make a short final conclusion for your task.
+Your task is to briefly summarize each step, then make a short final conclusion for your task. Give answer in {lang}.

 ##My Plans and Evidences##
 {plan_evidence}
@@ -11,6 +11,7 @@ class Solver(BaseComponent):
     model: BaseLLM
     prompt_template: Optional[PromptTemplate] = None
    examples: Optional[Union[str, List[str]]] = None
+    output_lang: str = "English"

     def _compose_fewshot_prompt(self) -> str:
         if self.examples is None:
@@ -20,7 +21,7 @@ class Solver(BaseComponent):
         else:
             return "\n\n".join([e.strip("\n") for e in self.examples])

-    def _compose_prompt(self, instruction, plan_evidence) -> str:
+    def _compose_prompt(self, instruction, plan_evidence, output_lang) -> str:
         """
         Compose the prompt from template, plan&evidence, examples and instruction.
         """
@@ -28,20 +29,28 @@ class Solver(BaseComponent):
         if self.prompt_template is not None:
             if "fewshot" in self.prompt_template.placeholders:
                 return self.prompt_template.populate(
-                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+                    plan_evidence=plan_evidence,
+                    fewshot=fewshot,
+                    task=instruction,
+                    lang=output_lang,
                 )
             else:
                 return self.prompt_template.populate(
-                    plan_evidence=plan_evidence, task=instruction
+                    plan_evidence=plan_evidence, task=instruction, lang=output_lang
                 )
         else:
             if self.examples is not None:
                 return few_shot_solver_prompt.populate(
-                    plan_evidence=plan_evidence, fewshot=fewshot, task=instruction
+                    plan_evidence=plan_evidence,
+                    fewshot=fewshot,
+                    task=instruction,
+                    lang=output_lang,
                 )
             else:
                 return zero_shot_solver_prompt.populate(
-                    plan_evidence=plan_evidence, task=instruction
+                    plan_evidence=plan_evidence,
+                    task=instruction,
+                    lang=output_lang,
                 )

     def run(
@@ -54,7 +63,7 @@ class Solver(BaseComponent):
         output.info("Running Solver")
         output.debug(f"Instruction: {instruction}")
         output.debug(f"Plan Evidence: {plan_evidence}")
-        prompt = self._compose_prompt(instruction, plan_evidence)
+        prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
         output.debug(f"Prompt: {prompt}")
         try:
             response = self.model(prompt)
@@ -63,3 +72,28 @@ class Solver(BaseComponent):
             output.error("Solver failed to retrieve response from LLM")

         return response
+
+    def stream(
+        self,
+        instruction: str,
+        plan_evidence: str,
+        output: BaseScratchPad = BaseScratchPad(),
+    ) -> Any:
+        response = ""
+        output.info("Running Solver")
+        output.debug(f"Instruction: {instruction}")
+        output.debug(f"Plan Evidence: {plan_evidence}")
+        prompt = self._compose_prompt(instruction, plan_evidence, self.output_lang)
+        output.debug(f"Prompt: {prompt}")
+        try:
+            for text in self.model.stream(prompt):
+                response += text.text
+                yield text
+            output.info("Planner run successful.")
+        except NotImplementedError:
+            response = self.model(prompt).text
+            output.info("Solver run successful.")
+        except ValueError:
+            output.error("Solver failed to retrieve response from LLM")
+
+        return response
@@ -1,4 +1,5 @@
 from typing import AnyStr, Optional, Type
+from urllib.error import HTTPError

 from langchain.utilities import SerpAPIWrapper
 from pydantic import BaseModel, Field
@@ -26,12 +27,17 @@ class GoogleSearchTool(BaseTool):
                 "install googlesearch using `pip3 install googlesearch-python` to "
                 "use this tool"
             )
-        output = ""
-        search_results = search(query, advanced=True)
-        if search_results:
-            output = "\n".join(
-                "{} {}".format(item.title, item.description) for item in search_results
-            )
+
+        try:
+            output = ""
+            search_results = search(query, advanced=True)
+            if search_results:
+                output = "\n".join(
+                    "{} {}".format(item.title, item.description)
+                    for item in search_results
+                )
+        except HTTPError:
+            output = "No evidence found."

         return output

@@ -2,9 +2,10 @@ from typing import AnyStr, Optional, Type

 from pydantic import BaseModel, Field

+from kotaemon.agents.tools.base import ToolException
 from kotaemon.llms import BaseLLM

-from .base import BaseTool, ToolException
+from .base import BaseTool


 class LLMArgs(BaseModel):