refractor agents (#100)
* refractor agents * minor cosmetic, add terminal ui for cli * pump to 0.3.4 * Add temporary path * fix unclose files in tests --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
d9e925eb75
commit
797df5a69c
3
.github/workflows/unit-test.yaml
vendored
3
.github/workflows/unit-test.yaml
vendored
|
@ -6,6 +6,9 @@ on:
|
||||||
push:
|
push:
|
||||||
branches: [main]
|
branches: [main]
|
||||||
|
|
||||||
|
env:
|
||||||
|
THEFLOW_TEMP_PATH: ./tmp
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
unit-test:
|
unit-test:
|
||||||
if: ${{ !cancelled() }}
|
if: ${{ !cancelled() }}
|
||||||
|
|
|
@ -1,6 +1,25 @@
|
||||||
from .base import AgentType, BaseAgent
|
from .base import BaseAgent
|
||||||
from .langchain import LangchainAgent
|
from .io import AgentFinish, AgentOutput, AgentType, BaseScratchPad
|
||||||
|
from .langchain_based import LangchainAgent
|
||||||
from .react.agent import ReactAgent
|
from .react.agent import ReactAgent
|
||||||
from .rewoo.agent import RewooAgent
|
from .rewoo.agent import RewooAgent
|
||||||
|
from .tools import BaseTool, ComponentTool, GoogleSearchTool, LLMTool, WikipediaTool
|
||||||
|
|
||||||
__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"]
|
__all__ = [
|
||||||
|
# agent
|
||||||
|
"BaseAgent",
|
||||||
|
"ReactAgent",
|
||||||
|
"RewooAgent",
|
||||||
|
"LangchainAgent",
|
||||||
|
# tool
|
||||||
|
"BaseTool",
|
||||||
|
"ComponentTool",
|
||||||
|
"GoogleSearchTool",
|
||||||
|
"WikipediaTool",
|
||||||
|
"LLMTool",
|
||||||
|
# io
|
||||||
|
"AgentType",
|
||||||
|
"AgentOutput",
|
||||||
|
"AgentFinish",
|
||||||
|
"BaseScratchPad",
|
||||||
|
]
|
||||||
|
|
|
@ -1,45 +1,13 @@
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
from theflow import Node, Param
|
from theflow import Node, Param
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from kotaemon.llms import PromptTemplate
|
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||||
from kotaemon.llms.chats.base import ChatLLM
|
|
||||||
from kotaemon.llms.completions.base import LLM
|
|
||||||
|
|
||||||
|
from .io import AgentOutput, AgentType
|
||||||
from .tools import BaseTool
|
from .tools import BaseTool
|
||||||
|
|
||||||
BaseLLM = Union[ChatLLM, LLM]
|
|
||||||
|
|
||||||
|
|
||||||
class AgentType(Enum):
|
|
||||||
"""
|
|
||||||
Enumerated type for agent types.
|
|
||||||
"""
|
|
||||||
|
|
||||||
openai = "openai"
|
|
||||||
openai_multi = "openai_multi"
|
|
||||||
openai_tool = "openai_tool"
|
|
||||||
self_ask = "self_ask"
|
|
||||||
react = "react"
|
|
||||||
rewoo = "rewoo"
|
|
||||||
vanilla = "vanilla"
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def get_agent_class(_type: "AgentType"):
|
|
||||||
"""
|
|
||||||
Get agent class from agent type.
|
|
||||||
:param _type: agent type
|
|
||||||
:return: agent class
|
|
||||||
"""
|
|
||||||
if _type == AgentType.rewoo:
|
|
||||||
from .rewoo.agent import RewooAgent
|
|
||||||
|
|
||||||
return RewooAgent
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown agent type: {_type}")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(BaseComponent):
|
class BaseAgent(BaseComponent):
|
||||||
"""Define base agent interface"""
|
"""Define base agent interface"""
|
||||||
|
@ -47,13 +15,17 @@ class BaseAgent(BaseComponent):
|
||||||
name: str = Param(help="Name of the agent.")
|
name: str = Param(help="Name of the agent.")
|
||||||
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
|
agent_type: AgentType = Param(help="Agent type, must be one of AgentType")
|
||||||
description: str = Param(
|
description: str = Param(
|
||||||
help="Description used to tell the model how/when/why to use the agent. "
|
help=(
|
||||||
"You can provide few-shot examples as a part of the description. This will be "
|
"Description used to tell the model how/when/why to use the agent. You can"
|
||||||
|
" provide few-shot examples as a part of the description. This will be"
|
||||||
" input to the prompt of LLM."
|
" input to the prompt of LLM."
|
||||||
)
|
)
|
||||||
llm: Union[BaseLLM, dict[str, BaseLLM]] = Node(
|
)
|
||||||
help="Specify LLM to be used in the model, cam be a dict to supply different "
|
llm: Optional[BaseLLM] = Node(
|
||||||
"LLMs to multiple purposes in the agent"
|
help=(
|
||||||
|
"LLM to be used for the agent (optional). LLM must implement BaseLLM"
|
||||||
|
" interface."
|
||||||
|
)
|
||||||
)
|
)
|
||||||
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
|
prompt_template: Optional[Union[PromptTemplate, dict[str, PromptTemplate]]] = Param(
|
||||||
help="A prompt template or a dict to supply different prompt to the agent"
|
help="A prompt template or a dict to supply different prompt to the agent"
|
||||||
|
@ -63,6 +35,25 @@ class BaseAgent(BaseComponent):
|
||||||
help="List of plugins / tools to be used in the agent",
|
help="List of plugins / tools to be used in the agent",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def safeguard_run(run_func, *args, **kwargs):
|
||||||
|
def wrapper(self, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
return run_func(self, *args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
return AgentOutput(
|
||||||
|
text="",
|
||||||
|
agent_type=self.agent_type,
|
||||||
|
status="failed",
|
||||||
|
error=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
def add_tools(self, tools: list[BaseTool]) -> None:
|
def add_tools(self, tools: list[BaseTool]) -> None:
|
||||||
"""Helper method to add tools and update agent state if needed"""
|
"""Helper method to add tools and update agent state if needed"""
|
||||||
self.plugins.extend(tools)
|
self.plugins.extend(tools)
|
||||||
|
|
||||||
|
def run(self, *args, **kwargs) -> AgentOutput | list[AgentOutput]:
|
||||||
|
"""Run the component."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
3
knowledgehub/agents/io/__init__.py
Normal file
3
knowledgehub/agents/io/__init__.py
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
from .base import AgentAction, AgentFinish, AgentOutput, AgentType, BaseScratchPad
|
||||||
|
|
||||||
|
__all__ = ["AgentOutput", "AgentFinish", "BaseScratchPad", "AgentType", "AgentAction"]
|
|
@ -2,7 +2,12 @@ import json
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, NamedTuple, Union
|
from enum import Enum
|
||||||
|
from typing import Any, Dict, Literal, NamedTuple, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Extra
|
||||||
|
|
||||||
|
from kotaemon.base import LLMInterface
|
||||||
|
|
||||||
|
|
||||||
def check_log():
|
def check_log():
|
||||||
|
@ -14,6 +19,20 @@ def check_log():
|
||||||
return os.environ.get("LOG_PATH", None) is not None
|
return os.environ.get("LOG_PATH", None) is not None
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType(Enum):
|
||||||
|
"""
|
||||||
|
Enumerated type for agent types.
|
||||||
|
"""
|
||||||
|
|
||||||
|
openai = "openai"
|
||||||
|
openai_multi = "openai_multi"
|
||||||
|
openai_tool = "openai_tool"
|
||||||
|
self_ask = "self_ask"
|
||||||
|
react = "react"
|
||||||
|
rewoo = "rewoo"
|
||||||
|
vanilla = "vanilla"
|
||||||
|
|
||||||
|
|
||||||
class BaseScratchPad:
|
class BaseScratchPad:
|
||||||
"""
|
"""
|
||||||
Base class for output handlers.
|
Base class for output handlers.
|
||||||
|
@ -217,3 +236,20 @@ class AgentFinish(NamedTuple):
|
||||||
|
|
||||||
return_values: dict
|
return_values: dict
|
||||||
log: str
|
log: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentOutput(LLMInterface, extra=Extra.allow): # type: ignore [call-arg]
|
||||||
|
"""Output from an agent.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The text output from the agent.
|
||||||
|
agent_type: The type of agent.
|
||||||
|
status: The status after executing the agent.
|
||||||
|
error: The error message if any.
|
||||||
|
"""
|
||||||
|
|
||||||
|
text: str
|
||||||
|
type: str = "agent"
|
||||||
|
agent_type: AgentType
|
||||||
|
status: Literal["finished", "stopped", "failed"]
|
||||||
|
error: Optional[str] = None
|
|
@ -4,12 +4,11 @@ from langchain.agents import AgentType as LCAgentType
|
||||||
from langchain.agents import initialize_agent
|
from langchain.agents import initialize_agent
|
||||||
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
from langchain.agents.agent import AgentExecutor as LCAgentExecutor
|
||||||
|
|
||||||
from kotaemon.agents.tools import BaseTool
|
from kotaemon.llms import LLM, ChatLLM
|
||||||
from kotaemon.base.schema import Document
|
|
||||||
from kotaemon.llms.chats.base import ChatLLM
|
|
||||||
from kotaemon.llms.completions.base import LLM
|
|
||||||
|
|
||||||
from .base import AgentType, BaseAgent
|
from .base import BaseAgent
|
||||||
|
from .io import AgentOutput, AgentType
|
||||||
|
from .tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
class LangchainAgent(BaseAgent):
|
class LangchainAgent(BaseAgent):
|
||||||
|
@ -54,7 +53,9 @@ class LangchainAgent(BaseAgent):
|
||||||
# reinit Langchain AgentExecutor
|
# reinit Langchain AgentExecutor
|
||||||
self.agent = initialize_agent(
|
self.agent = initialize_agent(
|
||||||
langchain_plugins,
|
langchain_plugins,
|
||||||
self.llm._obj,
|
# TODO: could cause bugs for non-langchain llms
|
||||||
|
# related to https://github.com/Cinnamon/kotaemon/issues/73
|
||||||
|
self.llm._obj, # type: ignore
|
||||||
agent=self.AGENT_TYPE_MAP[self.agent_type],
|
agent=self.AGENT_TYPE_MAP[self.agent_type],
|
||||||
handle_parsing_errors=True,
|
handle_parsing_errors=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
|
@ -65,17 +66,16 @@ class LangchainAgent(BaseAgent):
|
||||||
self.update_agent_tools()
|
self.update_agent_tools()
|
||||||
return
|
return
|
||||||
|
|
||||||
def run(self, instruction: str) -> Document:
|
def run(self, instruction: str) -> AgentOutput:
|
||||||
assert (
|
assert (
|
||||||
self.agent is not None
|
self.agent is not None
|
||||||
), "Lanchain AgentExecutor is not correclty initialized"
|
), "Lanchain AgentExecutor is not correclty initialized"
|
||||||
|
|
||||||
# Langchain AgentExecutor call
|
# Langchain AgentExecutor call
|
||||||
output = self.agent(instruction)["output"]
|
output = self.agent(instruction)["output"]
|
||||||
return Document(
|
|
||||||
|
return AgentOutput(
|
||||||
text=output,
|
text=output,
|
||||||
metadata={
|
agent_type=self.agent_type,
|
||||||
"agent": "langchain",
|
status="finished",
|
||||||
"cost": 0.0,
|
|
||||||
"usage": 0,
|
|
||||||
},
|
|
||||||
)
|
)
|
|
@ -4,12 +4,11 @@ from typing import Optional
|
||||||
|
|
||||||
from theflow import Param
|
from theflow import Param
|
||||||
|
|
||||||
from kotaemon.base.schema import Document
|
from kotaemon.agents.base import BaseAgent, BaseLLM
|
||||||
|
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
|
||||||
|
from kotaemon.agents.tools import BaseTool
|
||||||
from kotaemon.llms import PromptTemplate
|
from kotaemon.llms import PromptTemplate
|
||||||
|
|
||||||
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
|
||||||
from ..output.base import AgentAction, AgentFinish
|
|
||||||
|
|
||||||
FINAL_ANSWER_ACTION = "Final Answer:"
|
FINAL_ANSWER_ACTION = "Final Answer:"
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +21,7 @@ class ReactAgent(BaseAgent):
|
||||||
name: str = "ReactAgent"
|
name: str = "ReactAgent"
|
||||||
agent_type: AgentType = AgentType.react
|
agent_type: AgentType = AgentType.react
|
||||||
description: str = "ReactAgent for answering multi-step reasoning questions"
|
description: str = "ReactAgent for answering multi-step reasoning questions"
|
||||||
llm: BaseLLM | dict[str, BaseLLM]
|
llm: BaseLLM
|
||||||
prompt_template: Optional[PromptTemplate] = None
|
prompt_template: Optional[PromptTemplate] = None
|
||||||
plugins: list[BaseTool] = Param(
|
plugins: list[BaseTool] = Param(
|
||||||
default_callback=lambda _: [], help="List of tools to be used in the agent. "
|
default_callback=lambda _: [], help="List of tools to be used in the agent. "
|
||||||
|
@ -34,7 +33,7 @@ class ReactAgent(BaseAgent):
|
||||||
default_callback=lambda _: [],
|
default_callback=lambda _: [],
|
||||||
help="List of AgentAction and observation (tool) output",
|
help="List of AgentAction and observation (tool) output",
|
||||||
)
|
)
|
||||||
max_iterations = 10
|
max_iterations: int = 10
|
||||||
strict_decode: bool = False
|
strict_decode: bool = False
|
||||||
|
|
||||||
def _compose_plugin_description(self) -> str:
|
def _compose_plugin_description(self) -> str:
|
||||||
|
@ -141,7 +140,7 @@ class ReactAgent(BaseAgent):
|
||||||
"""
|
"""
|
||||||
self.intermediate_steps = []
|
self.intermediate_steps = []
|
||||||
|
|
||||||
def run(self, instruction, max_iterations=None):
|
def run(self, instruction, max_iterations=None) -> AgentOutput:
|
||||||
"""
|
"""
|
||||||
Run the agent with the given instruction.
|
Run the agent with the given instruction.
|
||||||
|
|
||||||
|
@ -161,11 +160,15 @@ class ReactAgent(BaseAgent):
|
||||||
logging.info(f"Running {self.name} with instruction: {instruction}")
|
logging.info(f"Running {self.name} with instruction: {instruction}")
|
||||||
total_cost = 0.0
|
total_cost = 0.0
|
||||||
total_token = 0
|
total_token = 0
|
||||||
|
status = "failed"
|
||||||
|
response_text = None
|
||||||
|
|
||||||
for _ in range(max_iterations):
|
for step_count in range(1, max_iterations + 1):
|
||||||
prompt = self._compose_prompt(instruction)
|
prompt = self._compose_prompt(instruction)
|
||||||
logging.info(f"Prompt: {prompt}")
|
logging.info(f"Prompt: {prompt}")
|
||||||
response = self.llm(prompt, stop=["Observation:"]) # type: ignore
|
response = self.llm(
|
||||||
|
prompt, stop=["Observation:"]
|
||||||
|
) # could cause bugs if llm doesn't have `stop` as a parameter
|
||||||
response_text = response.text
|
response_text = response.text
|
||||||
logging.info(f"Response: {response_text}")
|
logging.info(f"Response: {response_text}")
|
||||||
action_step = self._parse_output(response_text)
|
action_step = self._parse_output(response_text)
|
||||||
|
@ -185,13 +188,18 @@ class ReactAgent(BaseAgent):
|
||||||
|
|
||||||
self.intermediate_steps.append((action_step, result))
|
self.intermediate_steps.append((action_step, result))
|
||||||
if is_finished_chain:
|
if is_finished_chain:
|
||||||
|
logging.info(f"Finished after {step_count} steps.")
|
||||||
|
status = "finished"
|
||||||
break
|
break
|
||||||
|
else:
|
||||||
|
status = "stopped"
|
||||||
|
|
||||||
return Document(
|
return AgentOutput(
|
||||||
text=response_text,
|
text=response_text,
|
||||||
metadata={
|
agent_type=self.agent_type,
|
||||||
"agent": "react",
|
status=status,
|
||||||
"cost": total_cost,
|
total_tokens=total_token,
|
||||||
"usage": total_token,
|
total_cost=total_cost,
|
||||||
},
|
intermediate_steps=self.intermediate_steps,
|
||||||
|
max_iterations=max_iterations,
|
||||||
)
|
)
|
||||||
|
|
|
@ -3,15 +3,15 @@ import re
|
||||||
from concurrent.futures import ThreadPoolExecutor
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from theflow import Param
|
from theflow import Node, Param
|
||||||
|
|
||||||
from kotaemon.base.schema import Document
|
from kotaemon.agents.base import BaseAgent
|
||||||
|
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
|
||||||
|
from kotaemon.agents.tools import BaseTool
|
||||||
|
from kotaemon.agents.utils import get_plugin_response_content
|
||||||
from kotaemon.indices.qa import CitationPipeline
|
from kotaemon.indices.qa import CitationPipeline
|
||||||
from kotaemon.llms import LLM, ChatLLM, PromptTemplate
|
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||||
|
|
||||||
from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
|
|
||||||
from ..output.base import BaseScratchPad
|
|
||||||
from ..utils import get_plugin_response_content
|
|
||||||
from .planner import Planner
|
from .planner import Planner
|
||||||
from .solver import Solver
|
from .solver import Solver
|
||||||
|
|
||||||
|
@ -23,7 +23,8 @@ class RewooAgent(BaseAgent):
|
||||||
name: str = "RewooAgent"
|
name: str = "RewooAgent"
|
||||||
agent_type: AgentType = AgentType.rewoo
|
agent_type: AgentType = AgentType.rewoo
|
||||||
description: str = "RewooAgent for answering multi-step reasoning questions"
|
description: str = "RewooAgent for answering multi-step reasoning questions"
|
||||||
llm: BaseLLM | dict[str, BaseLLM] # {"Planner": xxx, "Solver": xxx}
|
planner_llm: BaseLLM
|
||||||
|
solver_llm: BaseLLM
|
||||||
prompt_template: dict[str, PromptTemplate] = Param(
|
prompt_template: dict[str, PromptTemplate] = Param(
|
||||||
default_callback=lambda _: {},
|
default_callback=lambda _: {},
|
||||||
help="A dict to supply different prompt to the agent.",
|
help="A dict to supply different prompt to the agent.",
|
||||||
|
@ -35,17 +36,22 @@ class RewooAgent(BaseAgent):
|
||||||
default_callback=lambda _: {}, help="Examples to be used in the agent."
|
default_callback=lambda _: {}, help="Examples to be used in the agent."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_llms(self):
|
@Node.auto(depends_on=["planner_llm", "plugins", "prompt_template", "examples"])
|
||||||
if isinstance(self.llm, ChatLLM) or isinstance(self.llm, LLM):
|
def planner(self):
|
||||||
return {"Planner": self.llm, "Solver": self.llm}
|
return Planner(
|
||||||
elif (
|
model=self.planner_llm,
|
||||||
isinstance(self.llm, dict)
|
plugins=self.plugins,
|
||||||
and "Planner" in self.llm
|
prompt_template=self.prompt_template.get("Planner", None),
|
||||||
and "Solver" in self.llm
|
examples=self.examples.get("Planner", None),
|
||||||
):
|
)
|
||||||
return {"Planner": self.llm["Planner"], "Solver": self.llm["Solver"]}
|
|
||||||
else:
|
@Node.auto(depends_on=["solver_llm", "prompt_template", "examples"])
|
||||||
raise ValueError("llm must be a BaseLLM or a dict with Planner and Solver.")
|
def solver(self):
|
||||||
|
return Solver(
|
||||||
|
model=self.solver_llm,
|
||||||
|
prompt_template=self.prompt_template.get("Solver", None),
|
||||||
|
examples=self.examples.get("Solver", None),
|
||||||
|
)
|
||||||
|
|
||||||
def _parse_plan_map(
|
def _parse_plan_map(
|
||||||
self, planner_response: str
|
self, planner_response: str
|
||||||
|
@ -76,13 +82,16 @@ class RewooAgent(BaseAgent):
|
||||||
|
|
||||||
plan_to_es: dict[str, list[str]] = dict()
|
plan_to_es: dict[str, list[str]] = dict()
|
||||||
plans: dict[str, str] = dict()
|
plans: dict[str, str] = dict()
|
||||||
|
prev_key = ""
|
||||||
for line in valid_chunk:
|
for line in valid_chunk:
|
||||||
if line.startswith("#Plan"):
|
key, description = line.split(":", 1)
|
||||||
plan = line.split(":", 1)[0].strip()
|
key = key.strip()
|
||||||
plans[plan] = line.split(":", 1)[1].strip()
|
if key.startswith("#Plan"):
|
||||||
plan_to_es[plan] = []
|
plans[key] = description.strip()
|
||||||
elif line.startswith("#E"):
|
plan_to_es[key] = []
|
||||||
plan_to_es[plan].append(line.split(":", 1)[0].strip())
|
prev_key = key
|
||||||
|
elif key.startswith("#E"):
|
||||||
|
plan_to_es[prev_key].append(key)
|
||||||
|
|
||||||
return plan_to_es, plans
|
return plan_to_es, plans
|
||||||
|
|
||||||
|
@ -218,7 +227,8 @@ class RewooAgent(BaseAgent):
|
||||||
if p.name == name:
|
if p.name == name:
|
||||||
return p
|
return p
|
||||||
|
|
||||||
def run(self, instruction: str, use_citation: bool = False) -> Document:
|
@BaseAgent.safeguard_run
|
||||||
|
def run(self, instruction: str, use_citation: bool = False) -> AgentOutput:
|
||||||
"""
|
"""
|
||||||
Run the agent with a given instruction.
|
Run the agent with a given instruction.
|
||||||
"""
|
"""
|
||||||
|
@ -226,27 +236,12 @@ class RewooAgent(BaseAgent):
|
||||||
total_cost = 0.0
|
total_cost = 0.0
|
||||||
total_token = 0
|
total_token = 0
|
||||||
|
|
||||||
planner_llm = self._get_llms()["Planner"]
|
|
||||||
solver_llm = self._get_llms()["Solver"]
|
|
||||||
|
|
||||||
planner = Planner(
|
|
||||||
model=planner_llm,
|
|
||||||
plugins=self.plugins,
|
|
||||||
prompt_template=self.prompt_template.get("Planner", None),
|
|
||||||
examples=self.examples.get("Planner", None),
|
|
||||||
)
|
|
||||||
solver = Solver(
|
|
||||||
model=solver_llm,
|
|
||||||
prompt_template=self.prompt_template.get("Solver", None),
|
|
||||||
examples=self.examples.get("Solver", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Plan
|
# Plan
|
||||||
planner_output = planner(instruction)
|
planner_output = self.planner(instruction)
|
||||||
plannner_text_output = planner_output.text
|
planner_text_output = planner_output.text
|
||||||
plan_to_es, plans = self._parse_plan_map(plannner_text_output)
|
plan_to_es, plans = self._parse_plan_map(planner_text_output)
|
||||||
planner_evidences, evidence_level = self._parse_planner_evidences(
|
planner_evidences, evidence_level = self._parse_planner_evidences(
|
||||||
plannner_text_output
|
planner_text_output
|
||||||
)
|
)
|
||||||
|
|
||||||
# Work
|
# Work
|
||||||
|
@ -260,20 +255,19 @@ class RewooAgent(BaseAgent):
|
||||||
worker_log += f"{e}: {worker_evidences[e]}\n"
|
worker_log += f"{e}: {worker_evidences[e]}\n"
|
||||||
|
|
||||||
# Solve
|
# Solve
|
||||||
solver_output = solver(instruction, worker_log)
|
solver_output = self.solver(instruction, worker_log)
|
||||||
solver_output_text = solver_output.text
|
solver_output_text = solver_output.text
|
||||||
if use_citation:
|
if use_citation:
|
||||||
citation_pipeline = CitationPipeline(llm=solver_llm)
|
citation_pipeline = CitationPipeline(llm=self.solver_llm)
|
||||||
citation = citation_pipeline(context=worker_log, question=instruction)
|
citation = citation_pipeline(context=worker_log, question=instruction)
|
||||||
else:
|
else:
|
||||||
citation = None
|
citation = None
|
||||||
|
|
||||||
return Document(
|
return AgentOutput(
|
||||||
text=solver_output_text,
|
text=solver_output_text,
|
||||||
metadata={
|
agent_type=self.agent_type,
|
||||||
"agent": "react",
|
status="finished",
|
||||||
"cost": total_cost,
|
total_tokens=total_token,
|
||||||
"usage": total_token,
|
total_cost=total_cost,
|
||||||
"citation": citation,
|
citation=citation,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
from kotaemon.agents.base import BaseLLM, BaseTool
|
||||||
|
from kotaemon.agents.io import BaseScratchPad
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from kotaemon.llms import PromptTemplate
|
from kotaemon.llms import PromptTemplate
|
||||||
|
|
||||||
from ..base import BaseLLM, BaseTool
|
|
||||||
from ..output.base import BaseScratchPad
|
|
||||||
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
from .prompt import few_shot_planner_prompt, zero_shot_planner_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from typing import Any, List, Optional, Union
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
from kotaemon.agents.io import BaseScratchPad
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from kotaemon.llms import PromptTemplate
|
from kotaemon.llms import BaseLLM, PromptTemplate
|
||||||
|
|
||||||
from ..base import BaseLLM
|
|
||||||
from ..output.base import BaseScratchPad
|
|
||||||
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
|
from .prompt import few_shot_solver_prompt, zero_shot_solver_prompt
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -11,8 +11,8 @@ class GoogleSearchArgs(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class GoogleSearchTool(BaseTool):
|
class GoogleSearchTool(BaseTool):
|
||||||
name = "google_search"
|
name: str = "google_search"
|
||||||
description = (
|
description: str = (
|
||||||
"A search engine retrieving top search results as snippets from Google. "
|
"A search engine retrieving top search results as snippets from Google. "
|
||||||
"Input should be a search query."
|
"Input should be a search query."
|
||||||
)
|
)
|
||||||
|
|
|
@ -14,8 +14,8 @@ class LLMArgs(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class LLMTool(BaseTool):
|
class LLMTool(BaseTool):
|
||||||
name = "llm"
|
name: str = "llm"
|
||||||
description = (
|
description: str = (
|
||||||
"A pretrained LLM like yourself. Useful when you need to act with "
|
"A pretrained LLM like yourself. Useful when you need to act with "
|
||||||
"general world knowledge and common sense. Prioritize it when you "
|
"general world knowledge and common sense. Prioritize it when you "
|
||||||
"are confident in solving the problem "
|
"are confident in solving the problem "
|
||||||
|
|
|
@ -48,8 +48,8 @@ class WikipediaArgs(BaseModel):
|
||||||
class WikipediaTool(BaseTool):
|
class WikipediaTool(BaseTool):
|
||||||
"""Tool that adds the capability to query the Wikipedia API."""
|
"""Tool that adds the capability to query the Wikipedia API."""
|
||||||
|
|
||||||
name = "wikipedia"
|
name: str = "wikipedia"
|
||||||
description = (
|
description: str = (
|
||||||
"Search engine from Wikipedia, retrieving relevant wiki page. "
|
"Search engine from Wikipedia, retrieving relevant wiki page. "
|
||||||
"Useful when you need to get holistic knowledge about people, "
|
"Useful when you need to get holistic knowledge about people, "
|
||||||
"places, companies, historical events, or other subjects. "
|
"places, companies, historical events, or other subjects. "
|
||||||
|
|
|
@ -114,6 +114,7 @@ class LLMInterface(AIMessage):
|
||||||
completion_tokens: int = -1
|
completion_tokens: int = -1
|
||||||
total_tokens: int = -1
|
total_tokens: int = -1
|
||||||
prompt_tokens: int = -1
|
prompt_tokens: int = -1
|
||||||
|
total_cost: float = 0
|
||||||
logits: list[list[float]] = Field(default_factory=list)
|
logits: list[list[float]] = Field(default_factory=list)
|
||||||
messages: list[AIMessage] = Field(default_factory=list)
|
messages: list[AIMessage] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@ import os
|
||||||
|
|
||||||
import click
|
import click
|
||||||
import yaml
|
import yaml
|
||||||
|
from trogon import tui
|
||||||
|
|
||||||
|
|
||||||
# check if the output is not a .yml file -> raise error
|
# check if the output is not a .yml file -> raise error
|
||||||
|
@ -14,6 +15,7 @@ def check_config_format(config):
|
||||||
raise ValueError("config must be yaml format.")
|
raise ValueError("config must be yaml format.")
|
||||||
|
|
||||||
|
|
||||||
|
@tui(command="ui", help="Open the terminal UI") # generate the terminal UI
|
||||||
@click.group()
|
@click.group()
|
||||||
def main():
|
def main():
|
||||||
pass
|
pass
|
||||||
|
@ -56,8 +58,10 @@ def export(export_path, output):
|
||||||
@click.option(
|
@click.option(
|
||||||
"--username",
|
"--username",
|
||||||
required=False,
|
required=False,
|
||||||
help="Username for the user. If not provided, the promptui will not have "
|
help=(
|
||||||
"authentication.",
|
"Username for the user. If not provided, the promptui will not have "
|
||||||
|
"authentication."
|
||||||
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--password",
|
"--password",
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .base import ChatLLM
|
from .base import ChatLLM
|
||||||
from .langchain_based import AzureChatOpenAI
|
from .langchain_based import AzureChatOpenAI, LCChatMixin
|
||||||
|
|
||||||
__all__ = ["ChatLLM", "AzureChatOpenAI"]
|
__all__ = ["ChatLLM", "AzureChatOpenAI", "LCChatMixin"]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from .base import LLM
|
from .base import LLM
|
||||||
from .langchain_based import AzureOpenAI, OpenAI
|
from .langchain_based import AzureOpenAI, LCCompletionMixin, OpenAI
|
||||||
|
|
||||||
__all__ = ["LLM", "OpenAI", "AzureOpenAI"]
|
__all__ = ["LLM", "OpenAI", "AzureOpenAI", "LCCompletionMixin"]
|
||||||
|
|
|
@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
|
||||||
# metadata and dependencies
|
# metadata and dependencies
|
||||||
[project]
|
[project]
|
||||||
name = "kotaemon"
|
name = "kotaemon"
|
||||||
version = "0.3.3"
|
version = "0.3.4"
|
||||||
requires-python = ">= 3.10"
|
requires-python = ">= 3.10"
|
||||||
description = "Kotaemon core library for AI development."
|
description = "Kotaemon core library for AI development."
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
@ -24,6 +24,7 @@ dependencies = [
|
||||||
"cookiecutter",
|
"cookiecutter",
|
||||||
"click",
|
"click",
|
||||||
"pandas",
|
"pandas",
|
||||||
|
"trogon",
|
||||||
]
|
]
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { text = "MIT License" }
|
license = { text = "MIT License" }
|
||||||
|
|
|
@ -3,52 +3,34 @@ from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
|
|
||||||
from kotaemon.agents.base import AgentType
|
from kotaemon.agents import (
|
||||||
from kotaemon.agents.langchain import LangchainAgent
|
AgentType,
|
||||||
from kotaemon.agents.react import ReactAgent
|
BaseTool,
|
||||||
from kotaemon.agents.rewoo import RewooAgent
|
GoogleSearchTool,
|
||||||
from kotaemon.agents.tools import BaseTool, GoogleSearchTool, LLMTool, WikipediaTool
|
LangchainAgent,
|
||||||
|
LLMTool,
|
||||||
|
ReactAgent,
|
||||||
|
RewooAgent,
|
||||||
|
WikipediaTool,
|
||||||
|
)
|
||||||
from kotaemon.llms import AzureChatOpenAI
|
from kotaemon.llms import AzureChatOpenAI
|
||||||
|
|
||||||
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
|
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
|
||||||
|
REWOO_VALID_PLAN = (
|
||||||
|
|
||||||
_openai_chat_completion_responses_rewoo = [
|
|
||||||
ChatCompletion.parse_obj(
|
|
||||||
{
|
|
||||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1692338378,
|
|
||||||
"model": "gpt-35-turbo",
|
|
||||||
"system_fingerprint": None,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": "stop",
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text,
|
|
||||||
"function_call": None,
|
|
||||||
"tool_calls": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for text in [
|
|
||||||
(
|
|
||||||
"#Plan1: Search for Cinnamon AI company on Google\n"
|
"#Plan1: Search for Cinnamon AI company on Google\n"
|
||||||
"#E1: google_search[Cinnamon AI company]\n"
|
"#E1: google_search[Cinnamon AI company]\n"
|
||||||
"#Plan2: Search for Cinnamon on Wikipedia\n"
|
"#Plan2: Search for Cinnamon on Wikipedia\n"
|
||||||
"#E2: wikipedia[Cinnamon]\n"
|
"#E2: wikipedia[Cinnamon]\n"
|
||||||
),
|
)
|
||||||
FINAL_RESPONSE_TEXT,
|
REWOO_INVALID_PLAN = (
|
||||||
]
|
"#E1: google_search[Cinnamon AI company]\n"
|
||||||
]
|
"#Plan2: Search for Cinnamon on Wikipedia\n"
|
||||||
|
"#E2: wikipedia[Cinnamon]\n"
|
||||||
|
)
|
||||||
|
|
||||||
_openai_chat_completion_responses_react = [
|
|
||||||
ChatCompletion.parse_obj(
|
def generate_chat_completion_obj(text):
|
||||||
|
return ChatCompletion.parse_obj(
|
||||||
{
|
{
|
||||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
||||||
"object": "chat.completion",
|
"object": "chat.completion",
|
||||||
|
@ -70,6 +52,20 @@ _openai_chat_completion_responses_react = [
|
||||||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_openai_chat_completion_responses_rewoo = [
|
||||||
|
generate_chat_completion_obj(text=text)
|
||||||
|
for text in [REWOO_VALID_PLAN, FINAL_RESPONSE_TEXT]
|
||||||
|
]
|
||||||
|
|
||||||
|
_openai_chat_completion_responses_rewoo_error = [
|
||||||
|
generate_chat_completion_obj(text=text)
|
||||||
|
for text in [REWOO_INVALID_PLAN, FINAL_RESPONSE_TEXT]
|
||||||
|
]
|
||||||
|
|
||||||
|
_openai_chat_completion_responses_react = [
|
||||||
|
generate_chat_completion_obj(text=text)
|
||||||
for text in [
|
for text in [
|
||||||
(
|
(
|
||||||
"I don't have prior knowledge about Cinnamon AI company, "
|
"I don't have prior knowledge about Cinnamon AI company, "
|
||||||
|
@ -91,28 +87,7 @@ _openai_chat_completion_responses_react = [
|
||||||
]
|
]
|
||||||
|
|
||||||
_openai_chat_completion_responses_react_langchain_tool = [
|
_openai_chat_completion_responses_react_langchain_tool = [
|
||||||
ChatCompletion.parse_obj(
|
generate_chat_completion_obj(text=text)
|
||||||
{
|
|
||||||
"id": "chatcmpl-7qyuw6Q1CFCpcKsMdFkmUPUa7JP2x",
|
|
||||||
"object": "chat.completion",
|
|
||||||
"created": 1692338378,
|
|
||||||
"model": "gpt-35-turbo",
|
|
||||||
"system_fingerprint": None,
|
|
||||||
"choices": [
|
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": "stop",
|
|
||||||
"message": {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": text,
|
|
||||||
"function_call": None,
|
|
||||||
"tool_calls": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"usage": {"completion_tokens": 9, "prompt_tokens": 10, "total_tokens": 19},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
for text in [
|
for text in [
|
||||||
(
|
(
|
||||||
"I don't have prior knowledge about Cinnamon AI company, "
|
"I don't have prior knowledge about Cinnamon AI company, "
|
||||||
|
@ -145,6 +120,25 @@ def llm():
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"openai.resources.chat.completions.Completions.create",
|
||||||
|
side_effect=_openai_chat_completion_responses_rewoo_error,
|
||||||
|
)
|
||||||
|
def test_agent_fail(openai_completion, llm, mock_google_search):
|
||||||
|
plugins = [
|
||||||
|
GoogleSearchTool(),
|
||||||
|
WikipediaTool(),
|
||||||
|
LLMTool(llm=llm),
|
||||||
|
]
|
||||||
|
|
||||||
|
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
|
||||||
|
|
||||||
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
|
openai_completion.assert_called()
|
||||||
|
assert not response
|
||||||
|
assert response.status == "failed"
|
||||||
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"openai.resources.chat.completions.Completions.create",
|
"openai.resources.chat.completions.Completions.create",
|
||||||
side_effect=_openai_chat_completion_responses_rewoo,
|
side_effect=_openai_chat_completion_responses_rewoo,
|
||||||
|
@ -156,7 +150,7 @@ def test_rewoo_agent(openai_completion, llm, mock_google_search):
|
||||||
LLMTool(llm=llm),
|
LLMTool(llm=llm),
|
||||||
]
|
]
|
||||||
|
|
||||||
agent = RewooAgent(llm=llm, plugins=plugins)
|
agent = RewooAgent(planner_llm=llm, solver_llm=llm, plugins=plugins)
|
||||||
|
|
||||||
response = agent("Tell me about Cinnamon AI company")
|
response = agent("Tell me about Cinnamon AI company")
|
||||||
openai_completion.assert_called()
|
openai_completion.assert_called()
|
||||||
|
|
|
@ -110,7 +110,7 @@ class TestInMemoryVectorStore:
|
||||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
db.delete(["3"])
|
db.delete(["3"])
|
||||||
db.save(save_path=tmp_path / "test_save_load_delete.json")
|
db.save(save_path=tmp_path / "test_save_load_delete.json")
|
||||||
f = open(tmp_path / "test_save_load_delete.json")
|
with open(tmp_path / "test_save_load_delete.json") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
assert (
|
assert (
|
||||||
"1" and "2" in data["text_id_to_ref_doc_id"]
|
"1" and "2" in data["text_id_to_ref_doc_id"]
|
||||||
|
@ -136,7 +136,7 @@ class TestSimpleFileVectorStore:
|
||||||
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
|
db = SimpleFileVectorStore(path=tmp_path / "test_save_load_delete.json")
|
||||||
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
db.add(embeddings=embeddings, metadatas=metadatas, ids=ids)
|
||||||
db.delete(["3"])
|
db.delete(["3"])
|
||||||
f = open(tmp_path / "test_save_load_delete.json")
|
with open(tmp_path / "test_save_load_delete.json") as f:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
assert (
|
assert (
|
||||||
"1" and "2" in data["text_id_to_ref_doc_id"]
|
"1" and "2" in data["text_id_to_ref_doc_id"]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user