Add Langchain Agent wrapper with OpenAI Function / Self-ask agent support (#82)

* update Param() type hint in MVP

* update default embedding endpoint

* update Langchain agent wrapper

* update langchain agent
Author:  Tuan Anh Nguyen Dang (Tadashi_Cin)
Date:    2023-11-20 16:26:08 +07:00 (committed by GitHub)
Parent:  0a3fc4b228
Commit:  8bb7ad91e0
9 changed files with 137 additions and 35 deletions

View File: agents/__init__.py

@@ -1,5 +1,6 @@
-from .base import BaseAgent
+from .base import AgentType, BaseAgent
+from .langchain import LangchainAgent
 from .react.agent import ReactAgent
 from .rewoo.agent import RewooAgent

-__all__ = ["BaseAgent", "ReactAgent", "RewooAgent"]
+__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"]

View File: agents/base.py

@@ -1,8 +1,6 @@
 from enum import Enum
 from typing import Dict, List, Optional, Union

-from pydantic import BaseModel
-
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.base import ChatLLM
 from kotaemon.llms.completions.base import LLM
@@ -17,10 +15,12 @@ class AgentType(Enum):
     """

     openai = "openai"
+    openai_multi = "openai_multi"
+    openai_tool = "openai_tool"
+    self_ask = "self_ask"
     react = "react"
     rewoo = "rewoo"
     vanilla = "vanilla"
+    openai_memory = "openai_memory"

     @staticmethod
     def get_agent_class(_type: "AgentType"):
@@ -37,16 +37,6 @@ class AgentType(Enum):
         raise ValueError(f"Unknown agent type: {_type}")


-class AgentOutput(BaseModel):
-    """
-    Pydantic model for agent output.
-    """
-
-    output: str
-    cost: float
-    token_usage: int
-
-
 class BaseAgent(BaseTool):
     name: str
     """Name of the agent."""
@ -62,6 +52,10 @@ class BaseAgent(BaseTool):
prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]] prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]]
"""A prompt template or a dict to supply different prompt to the agent """A prompt template or a dict to supply different prompt to the agent
""" """
plugins: List[BaseTool] plugins: List[BaseTool] = []
"""List of plugins / tools to be used in the agent """List of plugins / tools to be used in the agent
""" """
def add_tools(self, tools: List[BaseTool]) -> None:
"""Helper method to add tools and update agent state if needed"""
self.plugins.extend(tools)
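The new `add_tools` helper is the supported way to grow an agent's tool list after construction; subclasses can override it to refresh internal state, as the Langchain wrapper below does. A hedged usage sketch (assuming `ReactAgent` accepts the same `llm`/`plugins` keyword fields used elsewhere in this commit):

    agent = ReactAgent(llm=llm, plugins=[GoogleSearchTool()])
    agent.add_tools([WikipediaTool(), LLMTool(llm=llm)])
    assert len(agent.plugins) == 3  # the base implementation simply extends the list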

View File: agents/langchain.py (new file)

@@ -0,0 +1,85 @@
+from typing import List, Optional, Type
+
+from langchain.agents import AgentType as LCAgentType
+from langchain.agents import initialize_agent
+from langchain.agents.agent import AgentExecutor as LCAgentExecutor
+from pydantic import BaseModel, create_model
+
+from kotaemon.base.schema import Document
+from kotaemon.llms.chats.base import ChatLLM
+from kotaemon.llms.completions.base import LLM
+from kotaemon.pipelines.tools import BaseTool
+
+from .base import AgentType, BaseAgent
+
+
+class LangchainAgent(BaseAgent):
+    """Wrapper for Langchain Agent"""
+
+    name: str = "LangchainAgent"
+    agent_type: AgentType
+    description: str = "LangchainAgent for answering multi-step reasoning questions"
+    args_schema: Optional[Type[BaseModel]] = create_model(
+        "LangchainArgsSchema", instruction=(str, ...)
+    )
+    AGENT_TYPE_MAP = {
+        AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
+        AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
+        AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
+        AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
+    }
+    agent: Optional[LCAgentExecutor] = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.agent_type not in self.AGENT_TYPE_MAP:
+            raise NotImplementedError(
+                f"AgentType {self.agent_type} not supported by Langchain wrapper"
+            )
+        self.update_agent_tools()
+
+    def update_agent_tools(self):
+        assert isinstance(self.llm, (ChatLLM, LLM))
+        langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]
+
+        # a fix for search_doc tool name:
+        # use "Intermediate Answer" for self-ask agent
+        found_search_tool = False
+        if self.agent_type == AgentType.self_ask:
+            for plugin in langchain_plugins:
+                if plugin.name == "search_doc":
+                    plugin.name = "Intermediate Answer"
+                    langchain_plugins = [plugin]
+                    found_search_tool = True
+                    break
+
+        if self.agent_type != AgentType.self_ask or found_search_tool:
+            # re-init Langchain AgentExecutor
+            self.agent = initialize_agent(
+                langchain_plugins,
+                self.llm.agent,
+                agent=self.AGENT_TYPE_MAP[self.agent_type],
+                handle_parsing_errors=True,
+                verbose=True,
+            )
+
+    def add_tools(self, tools: List[BaseTool]) -> None:
+        super().add_tools(tools)
+        self.update_agent_tools()
+
+    def _run_tool(self, instruction: str) -> Document:
+        assert (
+            self.agent is not None
+        ), "Langchain AgentExecutor is not correctly initialized"
+
+        # Langchain AgentExecutor call
+        output = self.agent(instruction)["output"]
+        return Document(
+            text=output,
+            metadata={
+                "agent": "langchain",
+                "cost": 0.0,
+                "usage": 0,
+            },
+        )
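The wrapper translates kotaemon's `AgentType` into Langchain's executor types and rebuilds the `AgentExecutor` whenever the tool list changes; for `self_ask` it also renames the `search_doc` tool to the "Intermediate Answer" name that Langchain's self-ask agent expects. A minimal usage sketch mirroring the new test at the end of this commit (`llm` is assumed to be an `AzureChatOpenAI` instance):

    agent = LangchainAgent(
        llm=llm,
        plugins=[GoogleSearchTool(), WikipediaTool(), LLMTool(llm=llm)],
        agent_type=AgentType.react,  # or openai / openai_multi / self_ask
    )
    result = agent("Tell me about Cinnamon AI company")
    print(result.text)  # _run_tool wraps the executor output in a Document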

View File: agents/rewoo/agent.py

@@ -9,7 +9,7 @@ from kotaemon.base.schema import Document
 from kotaemon.llms import LLM, ChatLLM, PromptTemplate
 from kotaemon.pipelines.citation import CitationPipeline

-from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
+from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
 from ..output.base import BaseScratchPad
 from ..utils import get_plugin_response_content
 from .planner import Planner
@@ -28,7 +28,9 @@ class RewooAgent(BaseAgent):
         str, PromptTemplate
     ] = dict()  # {"Planner": xxx, "Solver": xxx}
     plugins: List[BaseTool] = list()
-    examples: Dict[str, Union[str, List[str]]] = dict()
+    examples: Dict[
+        str, Union[str, List[str]]
+    ] = dict()  # {"Planner": xxx, "Solver": xxx}
     args_schema: Optional[Type[BaseModel]] = create_model(
         "RewooArgsSchema", instruction=(str, ...)
     )
@@ -156,10 +158,6 @@ class RewooAgent(BaseAgent):
             if selected_plugin is None:
                 raise ValueError("Invalid plugin detected")
             tool_response = selected_plugin(tool_input)
-            # cumulate agent-as-plugin costs and tokens.
-            if isinstance(tool_response, AgentOutput):
-                result["plugin_cost"] = tool_response.cost
-                result["plugin_token"] = tool_response.token_usage
             result["evidence"] = get_plugin_response_content(tool_response)
         except ValueError:
             result["evidence"] = "No evidence found."

View File: agents/rewoo/planner.py

@@ -73,6 +73,7 @@ class Planner(BaseComponent):
         output.debug(f"Prompt: {prompt}")
         try:
             response = self.model(prompt)
+            self.log_progress(".planner", response=response)
             output.info("Planner run successful.")
         except ValueError as e:
             output.error("Planner failed to retrieve response from LLM")

View File: agents/utils.py

@@ -1,12 +1,12 @@
-from .base import AgentOutput
+from ...base import Document


 def get_plugin_response_content(output) -> str:
     """
     Wrapper for AgentOutput content return
     """
-    if isinstance(output, AgentOutput):
-        return output.output
+    if isinstance(output, Document):
+        return output.text
     else:
         return str(output)
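Since a plugin may now return either a kotaemon `Document` or any plain value, the helper normalizes both cases. A behavior sketch:

    from kotaemon.base.schema import Document

    get_plugin_response_content(Document(text="some evidence"))  # -> "some evidence"
    get_plugin_response_content(42)                              # -> "42"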

View File: indexing pipeline (ReaderIndexingPipeline)

@@ -43,14 +43,14 @@ class ReaderIndexingPipeline(BaseComponent):
     reader_name: str = "normal"  # "normal", "mathpix" or "ocr"
     chunk_size: int = 1024
     chunk_overlap: int = 256
-    vector_store: Param[BaseVectorStore] = Param(InMemoryVectorStore)
-    doc_store: Param[BaseDocumentStore] = Param(InMemoryDocumentStore)
+    vector_store: BaseVectorStore = Param(InMemoryVectorStore)
+    doc_store: BaseDocumentStore = Param(InMemoryDocumentStore)
     doc_parsers: List[DocParser] = []
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
         chunk_size=16,
     )

View File: QA pipeline (QuestionAnsweringPipeline / AgentQAPipeline)

@@ -49,14 +49,14 @@ class QuestionAnsweringPipeline(BaseComponent):
         request_timeout=60,
     )

-    vector_store: Param[BaseVectorStore] = Param(InMemoryVectorStore)
-    doc_store: Param[BaseDocumentStore] = Param(InMemoryDocumentStore)
+    vector_store: BaseVectorStore = Param(InMemoryVectorStore)
+    doc_store: BaseDocumentStore = Param(InMemoryDocumentStore)
     rerankers: Sequence[BaseRerankingPipeline] = []
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )
@@ -137,8 +137,9 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
             component=self.retrieving_pipeline,
         )
         if search_tool not in self.agent.plugins:
-            self.agent.plugins.append(search_tool)
+            self.agent.add_tools([search_tool])

     def run(self, question: str, use_citation: bool = False) -> Document:
-        answer = self.agent(question, use_citation=use_citation)
+        kwargs = {"use_citation": use_citation} if use_citation else {}
+        answer = self.agent(question, **kwargs)
         return answer
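The run method now forwards `use_citation` only when it is enabled, presumably so that agents whose call signature lacks that keyword (such as the new Langchain wrapper) are not passed an argument they cannot handle. A hedged call sketch (`qa_pipeline` is a hypothetical `AgentQAPipeline` instance):

    answer = qa_pipeline.run("Tell me about Cinnamon AI company", use_citation=True)
    # with use_citation=False, the agent is invoked with no extra kwargs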

View File: agent tests

@@ -4,6 +4,8 @@ import pytest
 from openai.types.chat.chat_completion import ChatCompletion

 from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.pipelines.agents.base import AgentType
+from kotaemon.pipelines.agents.langchain import LangchainAgent
 from kotaemon.pipelines.agents.react import ReactAgent
 from kotaemon.pipelines.agents.rewoo import RewooAgent
 from kotaemon.pipelines.tools import (
@@ -13,7 +15,7 @@ from kotaemon.pipelines.tools import (
     WikipediaTool,
 )

-FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
+FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"

 _openai_chat_completion_responses_rewoo = [
@@ -199,7 +201,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     agent = initialize_agent(
         langchain_plugins,
         llm.agent,
-        agent=AgentType.OPENAI_FUNCTIONS,
+        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
     )
     response = agent("Tell me about Cinnamon AI company")
@@ -207,6 +209,26 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     assert response


+@patch(
+    "openai.resources.chat.completions.Completions.create",
+    side_effect=_openai_chat_completion_responses_react,
+)
+def test_wrapper_agent_langchain(openai_completion, llm, mock_google_search):
+    plugins = [
+        GoogleSearchTool(),
+        WikipediaTool(),
+        LLMTool(llm=llm),
+    ]
+    agent = LangchainAgent(
+        llm=llm,
+        plugins=plugins,
+        agent_type=AgentType.react,
+    )
+    response = agent("Tell me about Cinnamon AI company")
+    openai_completion.assert_called()
+    assert response
+
+
 @patch(
     "openai.resources.chat.completions.Completions.create",
     side_effect=_openai_chat_completion_responses_react_langchain_tool,