diff --git a/.github/workflows/unit-test.yaml b/.github/workflows/unit-test.yaml index a965366..d0ce722 100644 --- a/.github/workflows/unit-test.yaml +++ b/.github/workflows/unit-test.yaml @@ -6,6 +6,9 @@ on: push: branches: [main] +env: + THEFLOW_TEMP_PATH: ./tmp + jobs: unit-test: if: ${{ !cancelled() }} diff --git a/knowledgehub/agents/__init__.py b/knowledgehub/agents/__init__.py index 349173c..81a18df 100644 --- a/knowledgehub/agents/__init__.py +++ b/knowledgehub/agents/__init__.py @@ -1,6 +1,25 @@ -from .base import AgentType, BaseAgent -from .langchain import LangchainAgent +from .base import BaseAgent +from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad +from .langchain_based import LangchainAgent from .react.agent import ReactAgent from .rewoo.agent import RewooAgent +from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool -__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"] +__all__ = [ + # agent + "BaseAgent", + "ReactAgent", + "RewooAgent", + "LangchainAgent", + # tool + "BaseTool", + "ComponentTool", + "GoogleSearchTool", + "WikipediaTool", + "LLMTool", + # io + "AgentType", + "AgentOutput", + "AgentFinish", + "BaseScratchPad", +] diff --git a/knowledgehub/agents/base.py b/knowledgehub/agents/base.py index 44ec9b1..acfddcb 100644 --- a/knowledgehub/agents/base.py +++ b/knowledgehub/agents/base.py @@ -1,45 +1,13 @@ -from enum import Enum from typing import Optional, Union from theflow import Node, Param from kotaemon.base import BaseComponent -from kotaemon.llms import PromptTemplate -from kotaemon.llms.chats.base import ChatLLM -from kotaemon.llms.completions.base import LLM +from kotaemon.llms import BaseLLM, PromptTemplate +from .io import AgentOutput, AgentType from .tools import BaseTool -BaseLLM = Union[ChatLLM, LLM] - - -class AgentType(Enum): - """ - Enumerated type for agent types. - """ - - openai = "openai" - openai_multi = "openai_multi" - openai_tool = "openai_tool" - self_ask = "self_ask" - react = "react" - rewoo = "rewoo" - vanilla = "vanilla" - - @staticmethod - def get_agent_class(_type: "AgentType"): - """ - Get agent class from agent type. - :param _type: agent type - :return: agent class - """ - if _type == AgentType.rewoo: - from .rewoo.agent import RewooAgent - - return RewooAgent - else: - raise ValueError(f"Unknown agent type: {_type}") - class BaseAgent(BaseComponent): """Define base agent interface""" @@ -47,13 +15,17 @@ class BaseAgent(BaseComponent): name: str = Param(help="Name of the agent.") agent_type: AgentType = Param(help="Agent type, must be one of AgentType") description: str = Param( - help="Description used to tell the model how/when/why to use the agent. " - "You can provide few-shot examples as a part of the description. This will be " - "input to the prompt of LLM." + help=( + "Description used to tell the model how/when/why to use the agent. You can" + " provide few-shot examples as a part of the description. This will be" + " input to the prompt of LLM." + ) ) - llm: Union[BaseLLM, dict[str, BaseLLM]] = Node( - help="Specify LLM to be used in the model, cam be a dict to supply different " - "LLMs to multiple purposes in the agent" + llm: Optional[BaseLLM] = Node( + help=( + "LLM to be used for the agent (optional). LLM must implement BaseLLM" + " interface." + ) ) prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param( help="A prompt template or a dict to supply different prompt to the agent" @@ -63,6 +35,25 @@ class BaseAgent(BaseComponent): help="List of plugins / tools to be used in the agent", ) + @staticmethod + def safeguard_run(run_func, *args, **kwargs): + def wrapper(self, *args, **kwargs): + try: + return run_func(self, *args, **kwargs) + except Exception as e: + return AgentOutput( + text="", + agent_type=self.agent_type, + status="failed", + error=str(e), + ) + + return wrapper + def add_tools(self, tools: list[BaseTool]) -> None: """Helper method to add tools and update agent state if needed""" self.plugins.extend(tools) + + def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]: + """Run the component.""" + raise NotImplementedError() diff --git a/knowledgehub/agents/io/__init__.py b/knowledgehub/agents/io/__init__.py new file mode 100644 index 0000000..95f1874 --- /dev/null +++ b/knowledgehub/agents/io/__init__.py @@ -0,0 +1,3 @@ +from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad + +__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"] diff --git a/knowledgehub/agents/output/base.py b/knowledgehub/agents/io/base.py similarity index 84% rename from knowledgehub/agents/output/base.py rename to knowledgehub/agents/io/base.py index 242daa5..c27eed0 100644 --- a/knowledgehub/agents/output/base.py +++ b/knowledgehub/agents/io/base.py @@ -2,7 +2,12 @@ import json import logging import os from dataclasses import dataclass -from typing import Any, Dict, NamedTuple, Union +from enum import Enum +from typing import Any, Dict, Literal, NamedTuple, Optional, Union + +from pydantic import Extra + +from kotaemon.base import LLMInterface def check_log(): @@ -14,6 +19,20 @@ def check_log(): return os.environ.get("LOG_PATH", None) is not None +class AgentType(Enum): + """ + Enumerated type for agent types. + """ + + openai = "openai" + openai_multi = "openai_multi" + openai_tool = "openai_tool" + self_ask = "self_ask" + react = "react" + rewoo = "rewoo" + vanilla = "vanilla" + + class BaseScratchPad: """ Base class for output handlers. @@ -217,3 +236,20 @@ class AgentFinish(NamedTuple): return_values: dict log: str + + +class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg] + """Output from an agent. + + Args: + text: The text output from the agent. + agent_type: The type of agent. + status: The status after executing the agent. + error: The error message if any. + """ + + text: str + type: str = "agent" + agent_type: AgentType + status: Literal["finished", "stopped", "failed"] + error: Optional[str] = None diff --git a/knowledgehub/agents/langchain.py b/knowledgehub/agents/langchain_based.py similarity index 83% rename from knowledgehub/agents/langchain.py rename to knowledgehub/agents/langchain_based.py index c36f7b7..8189b6f 100644 --- a/knowledgehub/agents/langchain.py +++ b/knowledgehub/agents/langchain_based.py @@ -4,12 +4,11 @@ from langchain.agents import AgentType as LCAgentType from langchain.agents import initialize_agent from langchain.agents.agent import AgentExecutor as LCAgentExecutor -from kotaemon.agents.tools import BaseTool -from kotaemon.base.schema import Document -from kotaemon.llms.chats.base import ChatLLM -from kotaemon.llms.completions.base import LLM +from kotaemon.llms import LLM, ChatLLM -from .base import AgentType, BaseAgent +from .base import BaseAgent +from .io import AgentOutput, AgentType +from .tools import BaseTool class LangchainAgent(BaseAgent): @@ -54,7 +53,9 @@ class LangchainAgent(BaseAgent): # reinit Langchain AgentExecutor self.agent = initialize_agent( langchain_plugins, - self.llm._obj, + # TODO: could cause bugs for non-langchain llms + # related to https://github.com/Cinnamon/kotaemon/issues/73 + self.llm._obj, # type: ignore agent=self.AGENT_TYPE_MAP[self.agent_type], handle_parsing_errors=True, verbose=True, @@ -65,17 +66,16 @@ class LangchainAgent(BaseAgent): self.update_agent_tools() return - def run(self, instruction: str) -> Document: + def run(self, instruction: str) -> AgentOutput: assert ( self.agent is not None ), "Lanchain AgentExecutor is not correclty initialized" + # Langchain AgentExecutor call output = self.agent(instruction)["output"] - return Document( + + return AgentOutput( text=output, - metadata={ - "agent": "langchain", - "cost": 0.0, - "usage": 0, - }, + agent_type=self.agent_type, + status="finished", ) diff --git a/knowledgehub/agents/output/__init__.py b/knowledgehub/agents/output/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/knowledgehub/agents/react/agent.py b/knowledgehub/agents/react/agent.py index cf70fcd..abb70bb 100644 --- a/knowledgehub/agents/react/agent.py +++ b/knowledgehub/agents/react/agent.py @@ -4,12 +4,11 @@ from typing import Optional from theflow import Param -from kotaemon.base.schema import Document +from kotaemon.agents.base import BaseAgent, BaseLLM +from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType +from kotaemon.agents.tools import BaseTool from kotaemon.llms import PromptTemplate -from ..base import AgentType, BaseAgent, BaseLLM, BaseTool -from ..output.base import AgentAction, AgentFinish - FINAL_ANSWER_ACTION = "Final Answer:" @@ -22,7 +21,7 @@ class ReactAgent(BaseAgent): name: str = "ReactAgent" agent_type: AgentType = AgentType.react description: str = "ReactAgent for answering multi-step reasoning questions" - llm: BaseLLM | dict[str, BaseLLM] + llm: BaseLLM prompt_template: Optional[PromptTemplate] = None plugins: list[BaseTool] = Param( default_callback=lambda _: [], help="List of tools to be used in the agent. " @@ -34,7 +33,7 @@ class ReactAgent(BaseAgent): default_callback=lambda _: [], help="List of AgentAction and observation (tool) output", ) - max_iterations = 10 + max_iterations: int = 10 strict_decode: bool = False def _compose_plugin_description(self) -> str: @@ -141,7 +140,7 @@ class ReactAgent(BaseAgent): """ self.intermediate_steps = [] - def run(self, instruction, max_iterations=None): + def run(self, instruction, max_iterations=None) -> AgentOutput: """ Run the agent with the given instruction. @@ -161,11 +160,15 @@ class ReactAgent(BaseAgent): logging.info(f"Running {self.name} with instruction: {instruction}") total_cost = 0.0 total_token = 0 + status = "failed" + response_text = None - for _ in range(max_iterations): + for step_count in range(1, max_iterations + 1): prompt = self._compose_prompt(instruction) logging.info(f"Prompt: {prompt}") - response = self.llm(prompt, stop=["Observation:"]) # type: ignore + response = self.llm( + prompt, stop=["Observation:"] + ) # could cause bugs if llm doesn't have `stop` as a parameter response_text = response.text logging.info(f"Response: {response_text}") action_step = self._parse_output(response_text) @@ -185,13 +188,18 @@ class ReactAgent(BaseAgent): self.intermediate_steps.append((action_step, result)) if is_finished_chain: + logging.info(f"Finished after {step_count} steps.") + status = "finished" break + else: + status = "stopped" - return Document( + return AgentOutput( text=response_text, - metadata={ - "agent": "react", - "cost": total_cost, - "usage": total_token, - }, + agent_type=self.agent_type, + status=status, + total_tokens=total_token, + total_cost=total_cost, + intermediate_steps=self.intermediate_steps, + max_iterations=max_iterations, ) diff --git a/knowledgehub/agents/rewoo/agent.py b/knowledgehub/agents/rewoo/agent.py index 3831c29..5c6fba3 100644 --- a/knowledgehub/agents/rewoo/agent.py +++ b/knowledgehub/agents/rewoo/agent.py @@ -3,15 +3,15 @@ import re from concurrent.futures import ThreadPoolExecutor from typing import Any -from theflow import Param +from theflow import Node, Param -from kotaemon.base.schema import Document +from kotaemon.agents.base import BaseAgent +from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad +from kotaemon.agents.tools import BaseTool +from kotaemon.agents.utils import get_plugin_response_content from kotaemon.indices.qa import CitationPipeline -from kotaemon.llms import LLM, ChatLLM, PromptTemplate +from kotaemon.llms import BaseLLM, PromptTemplate -from ..base import AgentType, BaseAgent, BaseLLM, BaseTool -from ..output.base import BaseScratchPad -from ..utils import get_plugin_response_content from .planner import Planner from .solver import Solver @@ -23,7 +23,8 @@ class RewooAgent(BaseAgent): name: str = "RewooAgent" agent_type: AgentType = AgentType.rewoo description: str = "RewooAgent for answering multi-step reasoning questions" - llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx} + planner_llm: BaseLLM + solver_llm: BaseLLM prompt_template: dict[str, PromptTemplate] = Param( default_callback=lambda _: {}, help="A dict to supply different prompt to the agent.", @@ -35,17 +36,22 @@ class RewooAgent(BaseAgent): default_callback=lambda _: {}, help="Examples to be used in the agent." ) - def _get_llms(self): - if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM): - return {"Planner": self.llm, "Solver": self.llm} - elif ( - isinstance(self.llm, dict) - and "Planner" in self.llm - and "Solver" in self.llm - ): - return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]} - else: - raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.") + @Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"]) + def planner(self): + return Planner( + model=self.planner_llm, + plugins=self.plugins, + prompt_template=self.prompt_template.get("Planner", None), + examples=self.examples.get("Planner", None), + ) + + @Node.auto(depends_on=["solver_llm", "prompt_template", "examples"]) + def solver(self): + return Solver( + model=self.solver_llm, + prompt_template=self.prompt_template.get("Solver", None), + examples=self.examples.get("Solver", None), + ) def _parse_plan_map( self, planner_response: str @@ -76,13 +82,16 @@ class RewooAgent(BaseAgent): plan_to_es: dict[str, list[str]] = dict() plans: dict[str, str] = dict() + prev_key = "" for line in valid_chunk: - if line.startswith("#Plan"): - plan = line.split(":", 1)[0].strip() - plans[plan] = line.split(":", 1)[1].strip() - plan_to_es[plan] = [] - elif line.startswith("#E"): - plan_to_es[plan].append(line.split(":", 1)[0].strip()) + key, description = line.split(":", 1) + key = key.strip() + if key.startswith("#Plan"): + plans[key] = description.strip() + plan_to_es[key] = [] + prev_key = key + elif key.startswith("#E"): + plan_to_es[prev_key].append(key) return plan_to_es, plans @@ -218,7 +227,8 @@ class RewooAgent(BaseAgent): if p.name == name: return p - def run(self, instruction: str, use_citation: bool = False) -> Document: + @BaseAgent.safeguard_run + def run(self, instruction: str, use_citation: bool = False) -> AgentOutput: """ Run the agent with a given instruction. """ @@ -226,27 +236,12 @@ class RewooAgent(BaseAgent): total_cost = 0.0 total_token = 0 - planner_llm = self._get_llms()["Planner"] - solver_llm = self._get_llms()["Solver"] - - planner = Planner( - model=planner_llm, - plugins=self.plugins, - prompt_template=self.prompt_template.get("Planner", None), - examples=self.examples.get("Planner", None), - ) - solver = Solver( - model=solver_llm, - prompt_template=self.prompt_template.get("Solver", None), - examples=self.examples.get("Solver", None), - ) - # Plan - planner_output = planner(instruction) - plannner_text_output = planner_output.text - plan_to_es, plans = self._parse_plan_map(plannner_text_output) + planner_output = self.planner(instruction) + planner_text_output = planner_output.text + plan_to_es, plans = self._parse_plan_map(planner_text_output) planner_evidences, evidence_level = self._parse_planner_evidences( - plannner_text_output + planner_text_output ) # Work @@ -260,20 +255,19 @@ class RewooAgent(BaseAgent): worker_log += f"{e}: {worker_evidences[e]}\n" # Solve - solver_output = solver(instruction, worker_log) + solver_output = self.solver(instruction, worker_log) solver_output_text = solver_output.text if use_citation: - citation_pipeline = CitationPipeline(llm=solver_llm) + citation_pipeline = CitationPipeline(llm=self.solver_llm) citation = citation_pipeline(context=worker_log, question=instruction) else: citation = None - return Document( + return AgentOutput( text=solver_output_text, - metadata={ - "agent": "react", - "cost": total_cost, - "usage": total_token, - "citation": citation, - }, + agent_type=self.agent_type, + status="finished", + total_tokens=total_token, + total_cost=total_cost, + citation=citation, ) diff --git a/knowledgehub/agents/rewoo/planner.py b/knowledgehub/agents/rewoo/planner.py index 51af140..588dba4 100644 --- a/knowledgehub/agents/rewoo/planner.py +++ b/knowledgehub/agents/rewoo/planner.py @@ -1,10 +1,10 @@ from typing import Any, List, Optional, Union +from kotaemon.agents.base import BaseLLM, BaseTool +from kotaemon.agents.io import BaseScratchPad from kotaemon.base import BaseComponent from kotaemon.llms import PromptTemplate -from ..base import BaseLLM, BaseTool -from ..output.base import BaseScratchPad from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt diff --git a/knowledgehub/agents/rewoo/solver.py b/knowledgehub/agents/rewoo/solver.py index d2ce271..9968bff 100644 --- a/knowledgehub/agents/rewoo/solver.py +++ b/knowledgehub/agents/rewoo/solver.py @@ -1,10 +1,9 @@ from typing import Any, List, Optional, Union +from kotaemon.agents.io import BaseScratchPad from kotaemon.base import BaseComponent -from kotaemon.llms import PromptTemplate +from kotaemon.llms import BaseLLM, PromptTemplate -from ..base import BaseLLM -from ..output.base import BaseScratchPad from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt diff --git a/knowledgehub/agents/tools/google.py b/knowledgehub/agents/tools/google.py index 2c257fa..ec84568 100644 --- a/knowledgehub/agents/tools/google.py +++ b/knowledgehub/agents/tools/google.py @@ -11,8 +11,8 @@ class GoogleSearchArgs(BaseModel): class GoogleSearchTool(BaseTool): - name = "google_search" - description = ( + name: str = "google_search" + description: str = ( "A search engine retrieving top search results as snippets from Google. " "Input should be a search query." ) diff --git a/knowledgehub/agents/tools/llm.py b/knowledgehub/agents/tools/llm.py index 25e9440..62c6fef 100644 --- a/knowledgehub/agents/tools/llm.py +++ b/knowledgehub/agents/tools/llm.py @@ -14,8 +14,8 @@ class LLMArgs(BaseModel): class LLMTool(BaseTool): - name = "llm" - description = ( + name: str = "llm" + description: str = ( "A pretrained LLM like yourself. Useful when you need to act with " "general world knowledge and common sense. Prioritize it when you " "are confident in solving the problem " diff --git a/knowledgehub/agents/tools/wikipedia.py b/knowledgehub/agents/tools/wikipedia.py index f060154..9e6a362 100644 --- a/knowledgehub/agents/tools/wikipedia.py +++ b/knowledgehub/agents/tools/wikipedia.py @@ -48,8 +48,8 @@ class WikipediaArgs(BaseModel): class WikipediaTool(BaseTool): """Tool that adds the capability to query the Wikipedia API.""" - name = "wikipedia" - description = ( + name: str = "wikipedia" + description: str = ( "Search engine from Wikipedia, retrieving relevant wiki page. " "Useful when you need to get holistic knowledge about people, " "places, companies, historical events, or other subjects. " diff --git a/knowledgehub/base/schema.py b/knowledgehub/base/schema.py index a767e7f..cfb5791 100644 --- a/knowledgehub/base/schema.py +++ b/knowledgehub/base/schema.py @@ -114,6 +114,7 @@ class LLMInterface(AIMessage): completion_tokens: int = -1 total_tokens: int = -1 prompt_tokens: int = -1 + total_cost: float = 0 logits: list[list[float]] = Field(default_factory=list) messages: list[AIMessage] = Field(default_factory=list) diff --git a/knowledgehub/cli.py b/knowledgehub/cli.py index fc69858..75101ff 100644 --- a/knowledgehub/cli.py +++ b/knowledgehub/cli.py @@ -2,6 +2,7 @@ import os import click import yaml +from trogon import tui # check if the output is not a .yml file -> raise error @@ -14,6 +15,7 @@ def check_config_format(config): raise ValueError("config must be yaml format.") +@tui(command="ui", help="Open the terminal UI") # generate the terminal UI @click.group() def main(): pass @@ -56,8 +58,10 @@ def export(export_path, output): @click.option( "--username", required=False, - help="Username for the user. If not provided, the promptui will not have " - "authentication.", + help=( + "Username for the user. If not provided, the promptui will not have " + "authentication." + ), ) @click.option( "--password", diff --git a/knowledgehub/llms/chats/__init__.py b/knowledgehub/llms/chats/__init__.py index 9388c5f..ffdc139 100644 --- a/knowledgehub/llms/chats/__init__.py +++ b/knowledgehub/llms/chats/__init__.py @@ -1,4 +1,4 @@ from .base import ChatLLM -from .langchain_based import AzureChatOpenAI +from .langchain_based import AzureChatOpenAI, LCChatMixin -__all__ = ["ChatLLM", "AzureChatOpenAI"] +__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"] diff --git a/knowledgehub/llms/completions/__init__.py b/knowledgehub/llms/completions/__init__.py index 3ef6b8b..b0f6b0e 100644 --- a/knowledgehub/llms/completions/__init__.py +++ b/knowledgehub/llms/completions/__init__.py @@ -1,4 +1,4 @@ from .base import LLM -from .langchain_based import AzureOpenAI, OpenAI +from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI -__all__ = ["LLM", "OpenAI", "AzureOpenAI"] +__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"] diff --git a/pyproject.toml b/pyproject.toml index 300af24..7c8a650 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"] # metadata and dependencies [project] name = "kotaemon" -version = "0.3.3" +version = "0.3.4" requires-python = ">= 3.10" description = "Kotaemon core library for AI development." dependencies = [ @@ -24,6 +24,7 @@ dependencies = [ "cookiecutter", "click", "pandas", + "trogon", ] readme = "README.md" license = { text = "MIT License" } diff --git a/tests/test_agent.py b/tests/test_agent.py index 4b7ea15..4a10060 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -3,73 +3,69 @@ from unittest.mock import patch import pytest from openai.types.chat.chat_completion import ChatCompletion -from kotaemon.agents.base import AgentType -from kotaemon.agents.langchain import LangchainAgent -from kotaemon.agents.react import ReactAgent -from kotaemon.agents.rewoo import RewooAgent -from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool +from kotaemon.agents import ( + AgentType, + BaseTool, + GoogleSearchTool, + LangchainAgent, + LLMTool, + ReactAgent, + RewooAgent, + WikipediaTool, +) from kotaemon.llms import AzureChatOpenAI FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!" +REWOO_VALID_PLAN = ( + "#Plan1: Search for Cinnamon AI company on Google\n" + "#E1: google_search[Cinnamon AI company]\n" + "#Plan2: Search for Cinnamon on Wikipedia\n" + "#E2: wikipedia[Cinnamon]\n" +) +REWOO_INVALID_PLAN = ( + "#E1: google_search[Cinnamon AI company]\n" + "#Plan2: Search for Cinnamon on Wikipedia\n" + "#E2: wikipedia[Cinnamon]\n" +) + + +def generate_chat_completion_obj(text): + return ChatCompletion.parse_obj( + { + "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", + "object": "chat.completion", + "created": 1692338378, + "model": "gpt-35-turbo", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "finish_reason": "stop", + "message": { + "role": "assistant", + "content": text, + "function_call": None, + "tool_calls": None, + }, + } + ], + "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, + } + ) _openai_chat_completion_responses_rewoo = [ - ChatCompletion.parse_obj( - { - "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", - "object": "chat.completion", - "created": 1692338378, - "model": "gpt-35-turbo", - "system_fingerprint": None, - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "message": { - "role": "assistant", - "content": text, - "function_call": None, - "tool_calls": None, - }, - } - ], - "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, - } - ) - for text in [ - ( - "#Plan1: Search for Cinnamon AI company on Google\n" - "#E1: google_search[Cinnamon AI company]\n" - "#Plan2: Search for Cinnamon on Wikipedia\n" - "#E2: wikipedia[Cinnamon]\n" - ), - FINAL_RESPONSE_TEXT, - ] + generate_chat_completion_obj(text=text) + for text in [REWOO_VALID_PLAN, FINAL_RESPONSE_TEXT] +] + +_openai_chat_completion_responses_rewoo_error = [ + generate_chat_completion_obj(text=text) + for text in [REWOO_INVALID_PLAN, FINAL_RESPONSE_TEXT] ] _openai_chat_completion_responses_react = [ - ChatCompletion.parse_obj( - { - "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", - "object": "chat.completion", - "created": 1692338378, - "model": "gpt-35-turbo", - "system_fingerprint": None, - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "message": { - "role": "assistant", - "content": text, - "function_call": None, - "tool_calls": None, - }, - } - ], - "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, - } - ) + generate_chat_completion_obj(text=text) for text in [ ( "I don't have prior knowledge about Cinnamon AI company, " @@ -91,28 +87,7 @@ _openai_chat_completion_responses_react = [ ] _openai_chat_completion_responses_react_langchain_tool = [ - ChatCompletion.parse_obj( - { - "id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x", - "object": "chat.completion", - "created": 1692338378, - "model": "gpt-35-turbo", - "system_fingerprint": None, - "choices": [ - { - "index": 0, - "finish_reason": "stop", - "message": { - "role": "assistant", - "content": text, - "function_call": None, - "tool_calls": None, - }, - } - ], - "usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19}, - } - ) + generate_chat_completion_obj(text=text) for text in [ ( "I don't have prior knowledge about Cinnamon AI company, " @@ -145,6 +120,25 @@ def llm(): ) +@patch( + "openai.resources.chat.completions.Completions.create", + side_effect=_openai_chat_completion_responses_rewoo_error, +) +def test_agent_fail(openai_completion, llm, mock_google_search): + plugins = [ + GoogleSearchTool(), + WikipediaTool(), + LLMTool(llm=llm), + ] + + agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins) + + response = agent("Tell me about Cinnamon AI company") + openai_completion.assert_called() + assert not response + assert response.status == "failed" + + @patch( "openai.resources.chat.completions.Completions.create", side_effect=_openai_chat_completion_responses_rewoo, @@ -156,7 +150,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search): LLMTool(llm=llm), ] - agent = RewooAgent(llm=llm, plugins=plugins) + agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins) response = agent("Tell me about Cinnamon AI company") openai_completion.assert_called() diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index fc9a30c..e086d0e 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -110,8 +110,8 @@ class TestInMemoryVectorStore: db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) db.delete(["3"]) db.save(save_path=tmp_path / "test_save_load_delete.json") - f = open(tmp_path / "test_save_load_delete.json") - data = json.load(f) + with open(tmp_path / "test_save_load_delete.json") as f: + data = json.load(f) assert ( "1" and "2" in data["text_id_to_ref_doc_id"] ), "save function does not save data completely" @@ -136,8 +136,8 @@ class TestSimpleFileVectorStore: db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json") db.add(embeddings=embeddings, metadatas=metadatas, ids=ids) db.delete(["3"]) - f = open(tmp_path / "test_save_load_delete.json") - data = json.load(f) + with open(tmp_path / "test_save_load_delete.json") as f: + data = json.load(f) assert ( "1" and "2" in data["text_id_to_ref_doc_id"] ), "save function does not save data completely"