refractor agents (#100)
* refractor agents * minor cosmetic, add terminal ui for cli * pump to 0.3.4 * Add temporary path * fix unclose files in tests --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
@@ -4,12 +4,11 @@ from typing import Optional
|
||||
|
||||
from theflow import Param
|
||||
|
||||
from kotaemon.base.schema import Document
|
||||
from kotaemon.agents.base import BaseAgent, BaseLLM
|
||||
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
|
||||
from kotaemon.agents.tools import BaseTool
|
||||
from kotaemon.llms import PromptTemplate
|
||||
|
||||
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
||||
from ..output.base import AgentAction, AgentFinish
|
||||
|
||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||
|
||||
|
||||
@@ -22,7 +21,7 @@ class ReactAgent(BaseAgent):
|
||||
name: str = "ReactAgent"
|
||||
agent_type: AgentType = AgentType.react
|
||||
description: str = "ReactAgent for answering multi-step reasoning questions"
|
||||
llm: BaseLLM | dict[str, BaseLLM]
|
||||
llm: BaseLLM
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
plugins: list[BaseTool] = Param(
|
||||
default_callback=lambda _: [], help="List of tools to be used in the agent. "
|
||||
@@ -34,7 +33,7 @@ class ReactAgent(BaseAgent):
|
||||
default_callback=lambda _: [],
|
||||
help="List of AgentAction and observation (tool) output",
|
||||
)
|
||||
max_iterations = 10
|
||||
max_iterations: int = 10
|
||||
strict_decode: bool = False
|
||||
|
||||
def _compose_plugin_description(self) -> str:
|
||||
@@ -141,7 +140,7 @@ class ReactAgent(BaseAgent):
|
||||
"""
|
||||
self.intermediate_steps = []
|
||||
|
||||
def run(self, instruction, max_iterations=None):
|
||||
def run(self, instruction, max_iterations=None) -> AgentOutput:
|
||||
"""
|
||||
Run the agent with the given instruction.
|
||||
|
||||
@@ -161,11 +160,15 @@ class ReactAgent(BaseAgent):
|
||||
logging.info(f"Running {self.name} with instruction: {instruction}")
|
||||
total_cost = 0.0
|
||||
total_token = 0
|
||||
status = "failed"
|
||||
response_text = None
|
||||
|
||||
for _ in range(max_iterations):
|
||||
for step_count in range(1, max_iterations + 1):
|
||||
prompt = self._compose_prompt(instruction)
|
||||
logging.info(f"Prompt: {prompt}")
|
||||
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
|
||||
response = self.llm(
|
||||
prompt, stop=["Observation:"]
|
||||
) # could cause bugs if llm doesn't have `stop` as a parameter
|
||||
response_text = response.text
|
||||
logging.info(f"Response: {response_text}")
|
||||
action_step = self._parse_output(response_text)
|
||||
@@ -185,13 +188,18 @@ class ReactAgent(BaseAgent):
|
||||
|
||||
self.intermediate_steps.append((action_step, result))
|
||||
if is_finished_chain:
|
||||
logging.info(f"Finished after {step_count} steps.")
|
||||
status = "finished"
|
||||
break
|
||||
else:
|
||||
status = "stopped"
|
||||
|
||||
return Document(
|
||||
return AgentOutput(
|
||||
text=response_text,
|
||||
metadata={
|
||||
"agent": "react",
|
||||
"cost": total_cost,
|
||||
"usage": total_token,
|
||||
},
|
||||
agent_type=self.agent_type,
|
||||
status=status,
|
||||
total_tokens=total_token,
|
||||
total_cost=total_cost,
|
||||
intermediate_steps=self.intermediate_steps,
|
||||
max_iterations=max_iterations,
|
||||
)
|
||||
|
Reference in New Issue
Block a user