From 8bb7ad91e0e2ae0cb112299a68f4cf8542b9f72e Mon Sep 17 00:00:00 2001 From: "Tuan Anh Nguyen Dang (Tadashi_Cin)" Date: Mon, 20 Nov 2023 16:26:08 +0700 Subject: [PATCH] Add Langchain Agent wrapper with OpenAI Function / Self-ask agent support (#82) * update Param() type hint in MVP * update default embedding endpoint * update Langchain agent wrapper * update langchain agent --- knowledgehub/pipelines/agents/__init__.py | 5 +- knowledgehub/pipelines/agents/base.py | 22 ++--- knowledgehub/pipelines/agents/langchain.py | 85 +++++++++++++++++++ knowledgehub/pipelines/agents/rewoo/agent.py | 10 +-- .../pipelines/agents/rewoo/planner.py | 1 + knowledgehub/pipelines/agents/utils.py | 6 +- knowledgehub/pipelines/ingest.py | 6 +- knowledgehub/pipelines/qa.py | 11 +-- tests/test_agent.py | 26 +++++- 9 files changed, 137 insertions(+), 35 deletions(-) create mode 100644 knowledgehub/pipelines/agents/langchain.py diff --git a/knowledgehub/pipelines/agents/__init__.py b/knowledgehub/pipelines/agents/__init__.py index a9caa08..349173c 100644 --- a/knowledgehub/pipelines/agents/__init__.py +++ b/knowledgehub/pipelines/agents/__init__.py @@ -1,5 +1,6 @@ -from .base import BaseAgent +from .base import AgentType, BaseAgent +from .langchain import LangchainAgent from .react.agent import ReactAgent from .rewoo.agent import RewooAgent -__all__ = ["BaseAgent", "ReactAgent", "RewooAgent"] +__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"] diff --git a/knowledgehub/pipelines/agents/base.py b/knowledgehub/pipelines/agents/base.py index cf8f2fe..c8f6dad 100644 --- a/knowledgehub/pipelines/agents/base.py +++ b/knowledgehub/pipelines/agents/base.py @@ -1,8 +1,6 @@ from enum import Enum from typing import Dict, List, Optional, Union -from pydantic import BaseModel - from kotaemon.llms import PromptTemplate from kotaemon.llms.chats.base import ChatLLM from kotaemon.llms.completions.base import LLM @@ -17,10 +15,12 @@ class AgentType(Enum): """ openai = "openai" + openai_multi = "openai_multi" + openai_tool = "openai_tool" + self_ask = "self_ask" react = "react" rewoo = "rewoo" vanilla = "vanilla" - openai_memory = "openai_memory" @staticmethod def get_agent_class(_type: "AgentType"): @@ -37,16 +37,6 @@ class AgentType(Enum): raise ValueError(f"Unknown agent type: {_type}") -class AgentOutput(BaseModel): - """ - Pydantic model for agent output. - """ - - output: str - cost: float - token_usage: int - - class BaseAgent(BaseTool): name: str """Name of the agent.""" @@ -62,6 +52,10 @@ class BaseAgent(BaseTool): prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]] """A prompt template or a dict to supply different prompt to the agent """ - plugins: List[BaseTool] + plugins: List[BaseTool] = [] """List of plugins / tools to be used in the agent """ + + def add_tools(self, tools: List[BaseTool]) -> None: + """Helper method to add tools and update agent state if needed""" + self.plugins.extend(tools) diff --git a/knowledgehub/pipelines/agents/langchain.py b/knowledgehub/pipelines/agents/langchain.py new file mode 100644 index 0000000..7048618 --- /dev/null +++ b/knowledgehub/pipelines/agents/langchain.py @@ -0,0 +1,85 @@ +from typing import List, Optional, Type + +from langchain.agents import AgentType as LCAgentType +from langchain.agents import initialize_agent +from langchain.agents.agent import AgentExecutor as LCAgentExecutor +from pydantic import BaseModel, create_model + +from kotaemon.base.schema import Document +from kotaemon.llms.chats.base import ChatLLM +from kotaemon.llms.completions.base import LLM +from kotaemon.pipelines.tools import BaseTool + +from .base import AgentType, BaseAgent + + +class LangchainAgent(BaseAgent): + """Wrapper for Langchain Agent""" + + name: str = "LangchainAgent" + agent_type: AgentType + description: str = "LangchainAgent for answering multi-step reasoning questions" + args_schema: Optional[Type[BaseModel]] = create_model( + "LangchainArgsSchema", instruction=(str, ...) + ) + AGENT_TYPE_MAP = { + AgentType.openai: LCAgentType.OPENAI_FUNCTIONS, + AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS, + AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION, + AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH, + } + agent: Optional[LCAgentExecutor] = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.agent_type not in self.AGENT_TYPE_MAP: + raise NotImplementedError( + f"AgentType {self.agent_type } not supported by Langchain wrapper" + ) + self.update_agent_tools() + + def update_agent_tools(self): + assert isinstance(self.llm, (ChatLLM, LLM)) + langchain_plugins = [tool.to_langchain_format() for tool in self.plugins] + + # a fix for search_doc tool name: + # use "Intermediate Answer" for self-ask agent + found_search_tool = False + if self.agent_type == AgentType.self_ask: + for plugin in langchain_plugins: + if plugin.name == "search_doc": + plugin.name = "Intermediate Answer" + langchain_plugins = [plugin] + found_search_tool = True + break + + if self.agent_type != AgentType.self_ask or found_search_tool: + # reinit Langchain AgentExecutor + self.agent = initialize_agent( + langchain_plugins, + self.llm.agent, + agent=self.AGENT_TYPE_MAP[self.agent_type], + handle_parsing_errors=True, + verbose=True, + ) + + def add_tools(self, tools: List[BaseTool]) -> None: + super().add_tools(tools) + self.update_agent_tools() + return + + def _run_tool(self, instruction: str) -> Document: + assert ( + self.agent is not None + ), "Lanchain AgentExecutor is not correclty initialized" + # Langchain AgentExecutor call + output = self.agent(instruction)["output"] + return Document( + text=output, + metadata={ + "agent": "langchain", + "cost": 0.0, + "usage": 0, + }, + ) diff --git a/knowledgehub/pipelines/agents/rewoo/agent.py b/knowledgehub/pipelines/agents/rewoo/agent.py index 1dcc5ae..d31b629 100644 --- a/knowledgehub/pipelines/agents/rewoo/agent.py +++ b/knowledgehub/pipelines/agents/rewoo/agent.py @@ -9,7 +9,7 @@ from kotaemon.base.schema import Document from kotaemon.llms import LLM, ChatLLM, PromptTemplate from kotaemon.pipelines.citation import CitationPipeline -from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool +from ..base import AgentType, BaseAgent, BaseLLM, BaseTool from ..output.base import BaseScratchPad from ..utils import get_plugin_response_content from .planner import Planner @@ -28,7 +28,9 @@ class RewooAgent(BaseAgent): str, PromptTemplate ] = dict() # {"Planner": xxx, "Solver": xxx} plugins: List[BaseTool] = list() - examples: Dict[str, Union[str, List[str]]] = dict() + examples: Dict[ + str, Union[str, List[str]] + ] = dict() # {"Planner": xxx, "Solver": xxx} args_schema: Optional[Type[BaseModel]] = create_model( "RewooArgsSchema", instruction=(str, ...) ) @@ -156,10 +158,6 @@ class RewooAgent(BaseAgent): if selected_plugin is None: raise ValueError("Invalid plugin detected") tool_response = selected_plugin(tool_input) - # cumulate agent-as-plugin costs and tokens. - if isinstance(tool_response, AgentOutput): - result["plugin_cost"] = tool_response.cost - result["plugin_token"] = tool_response.token_usage result["evidence"] = get_plugin_response_content(tool_response) except ValueError: result["evidence"] = "No evidence found." diff --git a/knowledgehub/pipelines/agents/rewoo/planner.py b/knowledgehub/pipelines/agents/rewoo/planner.py index c0624d8..5075f87 100644 --- a/knowledgehub/pipelines/agents/rewoo/planner.py +++ b/knowledgehub/pipelines/agents/rewoo/planner.py @@ -73,6 +73,7 @@ class Planner(BaseComponent): output.debug(f"Prompt: {prompt}") try: response = self.model(prompt) + self.log_progress(".planner", response=response) output.info("Planner run successful.") except ValueError as e: output.error("Planner failed to retrieve response from LLM") diff --git a/knowledgehub/pipelines/agents/utils.py b/knowledgehub/pipelines/agents/utils.py index 4845526..58cd2e5 100644 --- a/knowledgehub/pipelines/agents/utils.py +++ b/knowledgehub/pipelines/agents/utils.py @@ -1,12 +1,12 @@ -from .base import AgentOutput +from ...base import Document def get_plugin_response_content(output) -> str: """ Wrapper for AgentOutput content return """ - if isinstance(output, AgentOutput): - return output.output + if isinstance(output, Document): + return output.text else: return str(output) diff --git a/knowledgehub/pipelines/ingest.py b/knowledgehub/pipelines/ingest.py index af56853..da55189 100644 --- a/knowledgehub/pipelines/ingest.py +++ b/knowledgehub/pipelines/ingest.py @@ -43,14 +43,14 @@ class ReaderIndexingPipeline(BaseComponent): reader_name: str = "normal" # "normal", "mathpix" or "ocr" chunk_size: int = 1024 chunk_overlap: int = 256 - vector_store: _[BaseVectorStore] = _(InMemoryVectorStore) - doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore) + vector_store: BaseVectorStore = _(InMemoryVectorStore) + doc_store: BaseDocumentStore = _(InMemoryDocumentStore) doc_parsers: List[DocParser] = [] embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx( model="text-embedding-ada-002", deployment="dummy-q2-text-embedding", - azure_endpoint="https://bleh-dummy-2.openai.azure.com/", + azure_endpoint="https://bleh-dummy.openai.azure.com/", openai_api_key=os.environ.get("OPENAI_API_KEY", ""), chunk_size=16, ) diff --git a/knowledgehub/pipelines/qa.py b/knowledgehub/pipelines/qa.py index 7bb7322..b370648 100644 --- a/knowledgehub/pipelines/qa.py +++ b/knowledgehub/pipelines/qa.py @@ -49,14 +49,14 @@ class QuestionAnsweringPipeline(BaseComponent): request_timeout=60, ) - vector_store: _[BaseVectorStore] = _(InMemoryVectorStore) - doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore) + vector_store: BaseVectorStore = _(InMemoryVectorStore) + doc_store: BaseDocumentStore = _(InMemoryDocumentStore) rerankers: Sequence[BaseRerankingPipeline] = [] embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx( model="text-embedding-ada-002", deployment="dummy-q2-text-embedding", - azure_endpoint="https://bleh-dummy-2.openai.azure.com/", + azure_endpoint="https://bleh-dummy.openai.azure.com/", openai_api_key=os.environ.get("OPENAI_API_KEY", ""), ) @@ -137,8 +137,9 @@ class AgentQAPipeline(QuestionAnsweringPipeline): component=self.retrieving_pipeline, ) if search_tool not in self.agent.plugins: - self.agent.plugins.append(search_tool) + self.agent.add_tools([search_tool]) def run(self, question: str, use_citation: bool = False) -> Document: - answer = self.agent(question, use_citation=use_citation) + kwargs = {"use_citation": use_citation} if use_citation else {} + answer = self.agent(question, **kwargs) return answer diff --git a/tests/test_agent.py b/tests/test_agent.py index 74f8fa3..2c1673d 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -4,6 +4,8 @@ import pytest from openai.types.chat.chat_completion import ChatCompletion from kotaemon.llms.chats.openai import AzureChatOpenAI +from kotaemon.pipelines.agents.base import AgentType +from kotaemon.pipelines.agents.langchain import LangchainAgent from kotaemon.pipelines.agents.react import ReactAgent from kotaemon.pipelines.agents.rewoo import RewooAgent from kotaemon.pipelines.tools import ( @@ -13,7 +15,7 @@ from kotaemon.pipelines.tools import ( WikipediaTool, ) -FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!" +FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!" _openai_chat_completion_responses_rewoo = [ @@ -199,7 +201,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search): agent = initialize_agent( langchain_plugins, llm.agent, - agent=AgentType.OPENAI_FUNCTIONS, + agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, ) response = agent("Tell me about Cinnamon AI company") @@ -207,6 +209,26 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search): assert response +@patch( + "openai.resources.chat.completions.Completions.create", + side_effect=_openai_chat_completion_responses_react, +) +def test_wrapper_agent_langchain(openai_completion, llm, mock_google_search): + plugins = [ + GoogleSearchTool(), + WikipediaTool(), + LLMTool(llm=llm), + ] + agent = LangchainAgent( + llm=llm, + plugins=plugins, + agent_type=AgentType.react, + ) + response = agent("Tell me about Cinnamon AI company") + openai_completion.assert_called() + assert response + + @patch( "openai.resources.chat.completions.Completions.create", side_effect=_openai_chat_completion_responses_react_langchain_tool,