refractor agents (#100)

* refractor agents

* minor cosmetic, add terminal ui for cli

* pump to 0.3.4

* Add temporary path

* fix unclose files in tests

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
ian_Cin 2023-12-06 17:06:29 +07:00 committed by GitHub
parent d9e925eb75
commit 797df5a69c
21 changed files with 281 additions and 228 deletions

View File

@ -6,6 +6,9 @@ on:
push: push:
branches: [main] branches: [main]
env:
THEFLOW_TEMP_PATH: ./tmp
jobs: jobs:
unit-test: unit-test:
if: ${{ !cancelled() }} if: ${{ !cancelled() }}

View File

@ -1,6 +1,25 @@
from .base import AgentType, BaseAgent from .base import BaseAgent
from .langchain import LangchainAgent from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad
from .langchain_based import LangchainAgent
from .react.agent import ReactAgent from .react.agent import ReactAgent
from .rewoo.agent import RewooAgent from .rewoo.agent import RewooAgent
from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"] __all__ = [
# agent
"BaseAgent",
"ReactAgent",
"RewooAgent",
"LangchainAgent",
# tool
"BaseTool",
"ComponentTool",
"GoogleSearchTool",
"WikipediaTool",
"LLMTool",
# io
"AgentType",
"AgentOutput",
"AgentFinish",
"BaseScratchPad",
]

View File

@ -1,45 +1,13 @@
from enum import Enum
from typing import Optional, Union from typing import Optional, Union
from theflow import Node, Param from theflow import Node, Param
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate from kotaemon.llms import BaseLLM, PromptTemplate
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from .io import AgentOutput, AgentType
from .tools import BaseTool from .tools import BaseTool
BaseLLM = Union[ChatLLM, LLM]
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
@staticmethod
def get_agent_class(_type: "AgentType"):
"""
Get agent class from agent type.
:param _type: agent type
:return: agent class
"""
if _type == AgentType.rewoo:
from .rewoo.agent import RewooAgent
return RewooAgent
else:
raise ValueError(f"Unknown agent type: {_type}")
class BaseAgent(BaseComponent): class BaseAgent(BaseComponent):
"""Define base agent interface""" """Define base agent interface"""
@ -47,13 +15,17 @@ class BaseAgent(BaseComponent):
name: str = Param(help="Name of the agent.") name: str = Param(help="Name of the agent.")
agent_type: AgentType = Param(help="Agent type, must be one of AgentType") agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
description: str = Param( description: str = Param(
help="Description used to tell the model how/when/why to use the agent. " help=(
"You can provide few-shot examples as a part of the description. This will be " "Description used to tell the model how/when/why to use the agent. You can"
" provide few-shot examples as a part of the description. This will be"
" input to the prompt of LLM." " input to the prompt of LLM."
) )
llm: Union[BaseLLM, dict[str, BaseLLM]] = Node( )
help="Specify LLM to be used in the model, cam be a dict to supply different " llm: Optional[BaseLLM] = Node(
"LLMs to multiple purposes in the agent" help=(
"LLM to be used for the agent (optional). LLM must implement BaseLLM"
" interface."
)
) )
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param( prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
help="A prompt template or a dict to supply different prompt to the agent" help="A prompt template or a dict to supply different prompt to the agent"
@ -63,6 +35,25 @@ class BaseAgent(BaseComponent):
help="List of plugins / tools to be used in the agent", help="List of plugins / tools to be used in the agent",
) )
@staticmethod
def safeguard_run(run_func, *args, **kwargs):
def wrapper(self, *args, **kwargs):
try:
return run_func(self, *args, **kwargs)
except Exception as e:
return AgentOutput(
text="",
agent_type=self.agent_type,
status="failed",
error=str(e),
)
return wrapper
def add_tools(self, tools: list[BaseTool]) -> None: def add_tools(self, tools: list[BaseTool]) -> None:
"""Helper method to add tools and update agent state if needed""" """Helper method to add tools and update agent state if needed"""
self.plugins.extend(tools) self.plugins.extend(tools)
def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]:
"""Run the component."""
raise NotImplementedError()

View File

@ -0,0 +1,3 @@
from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad
__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"]

View File

@ -2,7 +2,12 @@ import json
import logging import logging
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, NamedTuple, Union from enum import Enum
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
from pydantic import Extra
from kotaemon.base import LLMInterface
def check_log(): def check_log():
@ -14,6 +19,20 @@ def check_log():
return os.environ.get("LOG_PATH", None) is not None return os.environ.get("LOG_PATH", None) is not None
class AgentType(Enum):
"""
Enumerated type for agent types.
"""
openai = "openai"
openai_multi = "openai_multi"
openai_tool = "openai_tool"
self_ask = "self_ask"
react = "react"
rewoo = "rewoo"
vanilla = "vanilla"
class BaseScratchPad: class BaseScratchPad:
""" """
Base class for output handlers. Base class for output handlers.
@ -217,3 +236,20 @@ class AgentFinish(NamedTuple):
return_values: dict return_values: dict
log: str log: str
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
"""Output from an agent.
Args:
text: The text output from the agent.
agent_type: The type of agent.
status: The status after executing the agent.
error: The error message if any.
"""
text: str
type: str = "agent"
agent_type: AgentType
status: Literal["finished", "stopped", "failed"]
error: Optional[str] = None

View File

@ -4,12 +4,11 @@ from langchain.agents import AgentType as LCAgentType
from langchain.agents import initialize_agent from langchain.agents import initialize_agent
from langchain.agents.agent import AgentExecutor as LCAgentExecutor from langchain.agents.agent import AgentExecutor as LCAgentExecutor
from kotaemon.agents.tools import BaseTool from kotaemon.llms import LLM, ChatLLM
from kotaemon.base.schema import Document
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from .base import AgentType, BaseAgent from .base import BaseAgent
from .io import AgentOutput, AgentType
from .tools import BaseTool
class LangchainAgent(BaseAgent): class LangchainAgent(BaseAgent):
@ -54,7 +53,9 @@ class LangchainAgent(BaseAgent):
# reinit Langchain AgentExecutor # reinit Langchain AgentExecutor
self.agent = initialize_agent( self.agent = initialize_agent(
langchain_plugins, langchain_plugins,
self.llm._obj, # TODO: could cause bugs for non-langchain llms
# related to https://github.com/Cinnamon/kotaemon/issues/73
self.llm._obj, # type: ignore
agent=self.AGENT_TYPE_MAP[self.agent_type], agent=self.AGENT_TYPE_MAP[self.agent_type],
handle_parsing_errors=True, handle_parsing_errors=True,
verbose=True, verbose=True,
@ -65,17 +66,16 @@ class LangchainAgent(BaseAgent):
self.update_agent_tools() self.update_agent_tools()
return return
def run(self, instruction: str) -> Document: def run(self, instruction: str) -> AgentOutput:
assert ( assert (
self.agent is not None self.agent is not None
), "Lanchain AgentExecutor is not correclty initialized" ), "Lanchain AgentExecutor is not correclty initialized"
# Langchain AgentExecutor call # Langchain AgentExecutor call
output = self.agent(instruction)["output"] output = self.agent(instruction)["output"]
return Document(
return AgentOutput(
text=output, text=output,
metadata={ agent_type=self.agent_type,
"agent": "langchain", status="finished",
"cost": 0.0,
"usage": 0,
},
) )

View File

@ -4,12 +4,11 @@ from typing import Optional
from theflow import Param from theflow import Param
from kotaemon.base.schema import Document from kotaemon.agents.base import BaseAgent, BaseLLM
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
from kotaemon.agents.tools import BaseTool
from kotaemon.llms import PromptTemplate from kotaemon.llms import PromptTemplate
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import AgentAction, AgentFinish
FINAL_ANSWER_ACTION = "Final Answer:" FINAL_ANSWER_ACTION = "Final Answer:"
@ -22,7 +21,7 @@ class ReactAgent(BaseAgent):
name: str = "ReactAgent" name: str = "ReactAgent"
agent_type: AgentType = AgentType.react agent_type: AgentType = AgentType.react
description: str = "ReactAgent for answering multi-step reasoning questions" description: str = "ReactAgent for answering multi-step reasoning questions"
llm: BaseLLM | dict[str, BaseLLM] llm: BaseLLM
prompt_template: Optional[PromptTemplate] = None prompt_template: Optional[PromptTemplate] = None
plugins: list[BaseTool] = Param( plugins: list[BaseTool] = Param(
default_callback=lambda _: [], help="List of tools to be used in the agent. " default_callback=lambda _: [], help="List of tools to be used in the agent. "
@ -34,7 +33,7 @@ class ReactAgent(BaseAgent):
default_callback=lambda _: [], default_callback=lambda _: [],
help="List of AgentAction and observation (tool) output", help="List of AgentAction and observation (tool) output",
) )
max_iterations = 10 max_iterations: int = 10
strict_decode: bool = False strict_decode: bool = False
def _compose_plugin_description(self) -> str: def _compose_plugin_description(self) -> str:
@ -141,7 +140,7 @@ class ReactAgent(BaseAgent):
""" """
self.intermediate_steps = [] self.intermediate_steps = []
def run(self, instruction, max_iterations=None): def run(self, instruction, max_iterations=None) -> AgentOutput:
""" """
Run the agent with the given instruction. Run the agent with the given instruction.
@ -161,11 +160,15 @@ class ReactAgent(BaseAgent):
logging.info(f"Running {self.name} with instruction: {instruction}") logging.info(f"Running {self.name} with instruction: {instruction}")
total_cost = 0.0 total_cost = 0.0
total_token = 0 total_token = 0
status = "failed"
response_text = None
for _ in range(max_iterations): for step_count in range(1, max_iterations + 1):
prompt = self._compose_prompt(instruction) prompt = self._compose_prompt(instruction)
logging.info(f"Prompt: {prompt}") logging.info(f"Prompt: {prompt}")
response = self.llm(prompt, stop=["Observation:"]) # type: ignore response = self.llm(
prompt, stop=["Observation:"]
) # could cause bugs if llm doesn't have `stop` as a parameter
response_text = response.text response_text = response.text
logging.info(f"Response: {response_text}") logging.info(f"Response: {response_text}")
action_step = self._parse_output(response_text) action_step = self._parse_output(response_text)
@ -185,13 +188,18 @@ class ReactAgent(BaseAgent):
self.intermediate_steps.append((action_step, result)) self.intermediate_steps.append((action_step, result))
if is_finished_chain: if is_finished_chain:
logging.info(f"Finished after {step_count} steps.")
status = "finished"
break break
else:
status = "stopped"
return Document( return AgentOutput(
text=response_text, text=response_text,
metadata={ agent_type=self.agent_type,
"agent": "react", status=status,
"cost": total_cost, total_tokens=total_token,
"usage": total_token, total_cost=total_cost,
}, intermediate_steps=self.intermediate_steps,
max_iterations=max_iterations,
) )

View File

@ -3,15 +3,15 @@ import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any
from theflow import Param from theflow import Node, Param
from kotaemon.base.schema import Document from kotaemon.agents.base import BaseAgent
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
from kotaemon.agents.tools import BaseTool
from kotaemon.agents.utils import get_plugin_response_content
from kotaemon.indices.qa import CitationPipeline from kotaemon.indices.qa import CitationPipeline
from kotaemon.llms import LLM, ChatLLM, PromptTemplate from kotaemon.llms import BaseLLM, PromptTemplate
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from ..utils import get_plugin_response_content
from .planner import Planner from .planner import Planner
from .solver import Solver from .solver import Solver
@ -23,7 +23,8 @@ class RewooAgent(BaseAgent):
name: str = "RewooAgent" name: str = "RewooAgent"
agent_type: AgentType = AgentType.rewoo agent_type: AgentType = AgentType.rewoo
description: str = "RewooAgent for answering multi-step reasoning questions" description: str = "RewooAgent for answering multi-step reasoning questions"
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx} planner_llm: BaseLLM
solver_llm: BaseLLM
prompt_template: dict[str, PromptTemplate] = Param( prompt_template: dict[str, PromptTemplate] = Param(
default_callback=lambda _: {}, default_callback=lambda _: {},
help="A dict to supply different prompt to the agent.", help="A dict to supply different prompt to the agent.",
@ -35,17 +36,22 @@ class RewooAgent(BaseAgent):
default_callback=lambda _: {}, help="Examples to be used in the agent." default_callback=lambda _: {}, help="Examples to be used in the agent."
) )
def _get_llms(self): @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM): def planner(self):
return {"Planner": self.llm, "Solver": self.llm} return Planner(
elif ( model=self.planner_llm,
isinstance(self.llm, dict) plugins=self.plugins,
and "Planner" in self.llm prompt_template=self.prompt_template.get("Planner", None),
and "Solver" in self.llm examples=self.examples.get("Planner", None),
): )
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
else: @Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.") def solver(self):
return Solver(
model=self.solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
def _parse_plan_map( def _parse_plan_map(
self, planner_response: str self, planner_response: str
@ -76,13 +82,16 @@ class RewooAgent(BaseAgent):
plan_to_es: dict[str, list[str]] = dict() plan_to_es: dict[str, list[str]] = dict()
plans: dict[str, str] = dict() plans: dict[str, str] = dict()
prev_key = ""
for line in valid_chunk: for line in valid_chunk:
if line.startswith("#Plan"): key, description = line.split(":", 1)
plan = line.split(":", 1)[0].strip() key = key.strip()
plans[plan] = line.split(":", 1)[1].strip() if key.startswith("#Plan"):
plan_to_es[plan] = [] plans[key] = description.strip()
elif line.startswith("#E"): plan_to_es[key] = []
plan_to_es[plan].append(line.split(":", 1)[0].strip()) prev_key = key
elif key.startswith("#E"):
plan_to_es[prev_key].append(key)
return plan_to_es, plans return plan_to_es, plans
@ -218,7 +227,8 @@ class RewooAgent(BaseAgent):
if p.name == name: if p.name == name:
return p return p
def run(self, instruction: str, use_citation: bool = False) -> Document: @BaseAgent.safeguard_run
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
""" """
Run the agent with a given instruction. Run the agent with a given instruction.
""" """
@ -226,27 +236,12 @@ class RewooAgent(BaseAgent):
total_cost = 0.0 total_cost = 0.0
total_token = 0 total_token = 0
planner_llm = self._get_llms()["Planner"]
solver_llm = self._get_llms()["Solver"]
planner = Planner(
model=planner_llm,
plugins=self.plugins,
prompt_template=self.prompt_template.get("Planner", None),
examples=self.examples.get("Planner", None),
)
solver = Solver(
model=solver_llm,
prompt_template=self.prompt_template.get("Solver", None),
examples=self.examples.get("Solver", None),
)
# Plan # Plan
planner_output = planner(instruction) planner_output = self.planner(instruction)
plannner_text_output = planner_output.text planner_text_output = planner_output.text
plan_to_es, plans = self._parse_plan_map(plannner_text_output) plan_to_es, plans = self._parse_plan_map(planner_text_output)
planner_evidences, evidence_level = self._parse_planner_evidences( planner_evidences, evidence_level = self._parse_planner_evidences(
plannner_text_output planner_text_output
) )
# Work # Work
@ -260,20 +255,19 @@ class RewooAgent(BaseAgent):
worker_log += f"{e}: {worker_evidences[e]}\n" worker_log += f"{e}: {worker_evidences[e]}\n"
# Solve # Solve
solver_output = solver(instruction, worker_log) solver_output = self.solver(instruction, worker_log)
solver_output_text = solver_output.text solver_output_text = solver_output.text
if use_citation: if use_citation:
citation_pipeline = CitationPipeline(llm=solver_llm) citation_pipeline = CitationPipeline(llm=self.solver_llm)
citation = citation_pipeline(context=worker_log, question=instruction) citation = citation_pipeline(context=worker_log, question=instruction)
else: else:
citation = None citation = None
return Document( return AgentOutput(
text=solver_output_text, text=solver_output_text,
metadata={ agent_type=self.agent_type,
"agent": "react", status="finished",
"cost": total_cost, total_tokens=total_token,
"usage": total_token, total_cost=total_cost,
"citation": citation, citation=citation,
},
) )

View File

@ -1,10 +1,10 @@
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
from kotaemon.agents.base import BaseLLM, BaseTool
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate from kotaemon.llms import PromptTemplate
from ..base import BaseLLM, BaseTool
from ..output.base import BaseScratchPad
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt

View File

@ -1,10 +1,9 @@
from typing import Any, List, Optional, Union from typing import Any, List, Optional, Union
from kotaemon.agents.io import BaseScratchPad
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.llms import PromptTemplate from kotaemon.llms import BaseLLM, PromptTemplate
from ..base import BaseLLM
from ..output.base import BaseScratchPad
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt

View File

@ -11,8 +11,8 @@ class GoogleSearchArgs(BaseModel):
class GoogleSearchTool(BaseTool): class GoogleSearchTool(BaseTool):
name = "google_search" name: str = "google_search"
description = ( description: str = (
"A search engine retrieving top search results as snippets from Google. " "A search engine retrieving top search results as snippets from Google. "
"Input should be a search query." "Input should be a search query."
) )

View File

@ -14,8 +14,8 @@ class LLMArgs(BaseModel):
class LLMTool(BaseTool): class LLMTool(BaseTool):
name = "llm" name: str = "llm"
description = ( description: str = (
"A pretrained LLM like yourself. Useful when you need to act with " "A pretrained LLM like yourself. Useful when you need to act with "
"general world knowledge and common sense. Prioritize it when you " "general world knowledge and common sense. Prioritize it when you "
"are confident in solving the problem " "are confident in solving the problem "

View File

@ -48,8 +48,8 @@ class WikipediaArgs(BaseModel):
class WikipediaTool(BaseTool): class WikipediaTool(BaseTool):
"""Tool that adds the capability to query the Wikipedia API.""" """Tool that adds the capability to query the Wikipedia API."""
name = "wikipedia" name: str = "wikipedia"
description = ( description: str = (
"Search engine from Wikipedia, retrieving relevant wiki page. " "Search engine from Wikipedia, retrieving relevant wiki page. "
"Useful when you need to get holistic knowledge about people, " "Useful when you need to get holistic knowledge about people, "
"places, companies, historical events, or other subjects. " "places, companies, historical events, or other subjects. "

View File

@ -114,6 +114,7 @@ class LLMInterface(AIMessage):
completion_tokens: int = -1 completion_tokens: int = -1
total_tokens: int = -1 total_tokens: int = -1
prompt_tokens: int = -1 prompt_tokens: int = -1
total_cost: float = 0
logits: list[list[float]] = Field(default_factory=list) logits: list[list[float]] = Field(default_factory=list)
messages: list[AIMessage] = Field(default_factory=list) messages: list[AIMessage] = Field(default_factory=list)

View File

@ -2,6 +2,7 @@ import os
import click import click
import yaml import yaml
from trogon import tui
# check if the output is not a .yml file -> raise error # check if the output is not a .yml file -> raise error
@ -14,6 +15,7 @@ def check_config_format(config):
raise ValueError("config must be yaml format.") raise ValueError("config must be yaml format.")
@tui(command="ui", help="Open the terminal UI") # generate the terminal UI
@click.group() @click.group()
def main(): def main():
pass pass
@ -56,8 +58,10 @@ def export(export_path, output):
@click.option( @click.option(
"--username", "--username",
required=False, required=False,
help="Username for the user. If not provided, the promptui will not have " help=(
"authentication.", "Username for the user. If not provided, the promptui will not have "
"authentication."
),
) )
@click.option( @click.option(
"--password", "--password",

View File

@ -1,4 +1,4 @@
from .base import ChatLLM from .base import ChatLLM
from .langchain_based import AzureChatOpenAI from .langchain_based import AzureChatOpenAI, LCChatMixin
__all__ = ["ChatLLM", "AzureChatOpenAI"] __all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]

View File

@ -1,4 +1,4 @@
from .base import LLM from .base import LLM
from .langchain_based import AzureOpenAI, OpenAI from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
__all__ = ["LLM", "OpenAI", "AzureOpenAI"] __all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]

View File

@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies # metadata and dependencies
[project] [project]
name = "kotaemon" name = "kotaemon"
version = "0.3.3" version = "0.3.4"
requires-python = ">= 3.10" requires-python = ">= 3.10"
description = "Kotaemon core library for AI development." description = "Kotaemon core library for AI development."
dependencies = [ dependencies = [
@ -24,6 +24,7 @@ dependencies = [
"cookiecutter", "cookiecutter",
"click", "click",
"pandas", "pandas",
"trogon",
] ]
readme = "README.md" readme = "README.md"
license = { text = "MIT License" } license = { text = "MIT License" }

View File

@ -3,52 +3,34 @@ from unittest.mock import patch
import pytest import pytest
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.agents.base import AgentType from kotaemon.agents import (
from kotaemon.agents.langchain import LangchainAgent AgentType,
from kotaemon.agents.react import ReactAgent BaseTool,
from kotaemon.agents.rewoo import RewooAgent GoogleSearchTool,
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool LangchainAgent,
LLMTool,
ReactAgent,
RewooAgent,
WikipediaTool,
)
from kotaemon.llms import AzureChatOpenAI from kotaemon.llms import AzureChatOpenAI
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!" FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
_openai_chat_completion_responses_rewoo = [
ChatCompletion.parse_obj(
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [
(
"#Plan1: Search for Cinnamon AI company on Google\n" "#Plan1: Search for Cinnamon AI company on Google\n"
"#E1: google_search[Cinnamon AI company]\n" "#E1: google_search[Cinnamon AI company]\n"
"#Plan2: Search for Cinnamon on Wikipedia\n" "#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n" "#E2: wikipedia[Cinnamon]\n"
), )
FINAL_RESPONSE_TEXT, REWOO_INVALID_PLAN = (
] "#E1: google_search[Cinnamon AI company]\n"
] "#Plan2: Search for Cinnamon on Wikipedia\n"
"#E2: wikipedia[Cinnamon]\n"
)
_openai_chat_completion_responses_react = [
ChatCompletion.parse_obj( def generate_chat_completion_obj(text):
return ChatCompletion.parse_obj(
{ {
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion", "object": "chat.completion",
@ -70,6 +52,20 @@ _openai_chat_completion_responses_react = [
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
} }
) )
_openai_chat_completion_responses_rewoo = [
generate_chat_completion_obj(text=text)
for text in [REWOO_VALID_PLAN, FINAL_RESPONSE_TEXT]
]
_openai_chat_completion_responses_rewoo_error = [
generate_chat_completion_obj(text=text)
for text in [REWOO_INVALID_PLAN, FINAL_RESPONSE_TEXT]
]
_openai_chat_completion_responses_react = [
generate_chat_completion_obj(text=text)
for text in [ for text in [
( (
"I don't have prior knowledge about Cinnamon AI company, " "I don't have prior knowledge about Cinnamon AI company, "
@ -91,28 +87,7 @@ _openai_chat_completion_responses_react = [
] ]
_openai_chat_completion_responses_react_langchain_tool = [ _openai_chat_completion_responses_react_langchain_tool = [
ChatCompletion.parse_obj( generate_chat_completion_obj(text=text)
{
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
"object": "chat.completion",
"created": 1692338378,
"model": "gpt-35-turbo",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"finish_reason": "stop",
"message": {
"role": "assistant",
"content": text,
"function_call": None,
"tool_calls": None,
},
}
],
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
}
)
for text in [ for text in [
( (
"I don't have prior knowledge about Cinnamon AI company, " "I don't have prior knowledge about Cinnamon AI company, "
@ -145,6 +120,25 @@ def llm():
) )
@patch(
"openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo_error,
)
def test_agent_fail(openai_completion, llm, mock_google_search):
plugins = [
GoogleSearchTool(),
WikipediaTool(),
LLMTool(llm=llm),
]
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called()
assert not response
assert response.status == "failed"
@patch( @patch(
"openai.resources.chat.completions.Completions.create", "openai.resources.chat.completions.Completions.create",
side_effect=_openai_chat_completion_responses_rewoo, side_effect=_openai_chat_completion_responses_rewoo,
@ -156,7 +150,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
LLMTool(llm=llm), LLMTool(llm=llm),
] ]
agent = RewooAgent(llm=llm, plugins=plugins) agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
response = agent("Tell me about Cinnamon AI company") response = agent("Tell me about Cinnamon AI company")
openai_completion.assert_called() openai_completion.assert_called()

View File

@ -110,7 +110,7 @@ class TestInMemoryVectorStore:
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.delete(["3"]) db.delete(["3"])
db.save(save_path=tmp_path / "test_save_load_delete.json") db.save(save_path=tmp_path / "test_save_load_delete.json")
f = open(tmp_path / "test_save_load_delete.json") with open(tmp_path / "test_save_load_delete.json") as f:
data = json.load(f) data = json.load(f)
assert ( assert (
"1" and "2" in data["text_id_to_ref_doc_id"] "1" and "2" in data["text_id_to_ref_doc_id"]
@ -136,7 +136,7 @@ class TestSimpleFileVectorStore:
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json") db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
db.delete(["3"]) db.delete(["3"])
f = open(tmp_path / "test_save_load_delete.json") with open(tmp_path / "test_save_load_delete.json") as f:
data = json.load(f) data = json.load(f)
assert ( assert (
"1" and "2" in data["text_id_to_ref_doc_id"] "1" and "2" in data["text_id_to_ref_doc_id"]