Add Langchain Agent wrapper with OpenAI Function / Self-ask agent support (#82)
* update Param() type hint in MVP
* update default embedding endpoint
* update Langchain agent wrapper
* update langchain agent

This commit is contained in:
parent 0a3fc4b228
commit 8bb7ad91e0
@@ -1,5 +1,6 @@
-from .base import BaseAgent
+from .base import AgentType, BaseAgent
+from .langchain import LangchainAgent
 from .react.agent import ReactAgent
 from .rewoo.agent import RewooAgent
 
-__all__ = ["BaseAgent", "ReactAgent", "RewooAgent"]
+__all__ = ["BaseAgent", "ReactAgent", "RewooAgent", "LangchainAgent", "AgentType"]
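Both new names are re-exported through `__all__`, so downstream code can import them directly from the agents package. A minimal sketch (the `kotaemon.pipelines.agents` import path is inferred from the test imports at the bottom of this commit):

```python
from kotaemon.pipelines.agents import AgentType, LangchainAgent
```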
@@ -1,8 +1,6 @@
 from enum import Enum
 from typing import Dict, List, Optional, Union
 
-from pydantic import BaseModel
-
 from kotaemon.llms import PromptTemplate
 from kotaemon.llms.chats.base import ChatLLM
 from kotaemon.llms.completions.base import LLM
@@ -17,10 +15,12 @@ class AgentType(Enum):
     """
 
     openai = "openai"
+    openai_multi = "openai_multi"
+    openai_tool = "openai_tool"
+    self_ask = "self_ask"
     react = "react"
     rewoo = "rewoo"
     vanilla = "vanilla"
-    openai_memory = "openai_memory"
 
     @staticmethod
     def get_agent_class(_type: "AgentType"):
@@ -37,16 +37,6 @@ class AgentType(Enum):
         raise ValueError(f"Unknown agent type: {_type}")
 
 
-class AgentOutput(BaseModel):
-    """
-    Pydantic model for agent output.
-    """
-
-    output: str
-    cost: float
-    token_usage: int
-
-
 class BaseAgent(BaseTool):
     name: str
     """Name of the agent."""
@@ -62,6 +52,10 @@ class BaseAgent(BaseTool):
     prompt_template: Optional[Union[PromptTemplate, Dict[str, PromptTemplate]]]
     """A prompt template or a dict to supply different prompt to the agent
     """
-    plugins: List[BaseTool]
+    plugins: List[BaseTool] = []
     """List of plugins / tools to be used in the agent
     """
+
+    def add_tools(self, tools: List[BaseTool]) -> None:
+        """Helper method to add tools and update agent state if needed"""
+        self.plugins.extend(tools)
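Two notes on this hunk, illustrated with a small sketch (the asserts are illustrative, not taken from the test suite): the three new `AgentType` members exist to be mapped onto Langchain agent flavours in `langchain.py` below, and `BaseAgent.add_tools` is the extension point that `LangchainAgent` overrides to refresh its derived state.

```python
from kotaemon.pipelines.agents import AgentType

# the new members line up with Langchain agent flavours; openai_memory is gone
assert AgentType.openai_multi.value == "openai_multi"
assert AgentType.openai_tool.value == "openai_tool"
assert AgentType.self_ask.value == "self_ask"

# BaseAgent.add_tools simply extends the (now defaulted-to-[]) plugin list;
# subclasses may override it when adding a tool must also rebuild other state
```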
knowledgehub/pipelines/agents/langchain.py (new file, 85 lines)

@@ -0,0 +1,85 @@
+from typing import List, Optional, Type
+
+from langchain.agents import AgentType as LCAgentType
+from langchain.agents import initialize_agent
+from langchain.agents.agent import AgentExecutor as LCAgentExecutor
+from pydantic import BaseModel, create_model
+
+from kotaemon.base.schema import Document
+from kotaemon.llms.chats.base import ChatLLM
+from kotaemon.llms.completions.base import LLM
+from kotaemon.pipelines.tools import BaseTool
+
+from .base import AgentType, BaseAgent
+
+
+class LangchainAgent(BaseAgent):
+    """Wrapper for Langchain Agent"""
+
+    name: str = "LangchainAgent"
+    agent_type: AgentType
+    description: str = "LangchainAgent for answering multi-step reasoning questions"
+    args_schema: Optional[Type[BaseModel]] = create_model(
+        "LangchainArgsSchema", instruction=(str, ...)
+    )
+    AGENT_TYPE_MAP = {
+        AgentType.openai: LCAgentType.OPENAI_FUNCTIONS,
+        AgentType.openai_multi: LCAgentType.OPENAI_MULTI_FUNCTIONS,
+        AgentType.react: LCAgentType.ZERO_SHOT_REACT_DESCRIPTION,
+        AgentType.self_ask: LCAgentType.SELF_ASK_WITH_SEARCH,
+    }
+    agent: Optional[LCAgentExecutor] = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if self.agent_type not in self.AGENT_TYPE_MAP:
+            raise NotImplementedError(
+                f"AgentType {self.agent_type} not supported by Langchain wrapper"
+            )
+        self.update_agent_tools()
+
+    def update_agent_tools(self):
+        assert isinstance(self.llm, (ChatLLM, LLM))
+        langchain_plugins = [tool.to_langchain_format() for tool in self.plugins]
+
+        # a fix for search_doc tool name:
+        # use "Intermediate Answer" for self-ask agent
+        found_search_tool = False
+        if self.agent_type == AgentType.self_ask:
+            for plugin in langchain_plugins:
+                if plugin.name == "search_doc":
+                    plugin.name = "Intermediate Answer"
+                    langchain_plugins = [plugin]
+                    found_search_tool = True
+                    break
+
+        if self.agent_type != AgentType.self_ask or found_search_tool:
+            # reinit Langchain AgentExecutor
+            self.agent = initialize_agent(
+                langchain_plugins,
+                self.llm.agent,
+                agent=self.AGENT_TYPE_MAP[self.agent_type],
+                handle_parsing_errors=True,
+                verbose=True,
+            )
+
+    def add_tools(self, tools: List[BaseTool]) -> None:
+        super().add_tools(tools)
+        self.update_agent_tools()
+        return
+
+    def _run_tool(self, instruction: str) -> Document:
+        assert (
+            self.agent is not None
+        ), "Langchain AgentExecutor is not correctly initialized"
+        # Langchain AgentExecutor call
+        output = self.agent(instruction)["output"]
+        return Document(
+            text=output,
+            metadata={
+                "agent": "langchain",
+                "cost": 0.0,
+                "usage": 0,
+            },
+        )
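Putting the wrapper to work, mirroring the `test_wrapper_agent_langchain` test added at the end of this commit (`llm` is assumed to be an already-configured `AzureChatOpenAI` instance):

```python
from kotaemon.pipelines.agents.base import AgentType
from kotaemon.pipelines.agents.langchain import LangchainAgent
from kotaemon.pipelines.tools import GoogleSearchTool, LLMTool, WikipediaTool

agent = LangchainAgent(
    llm=llm,  # assumed: a configured AzureChatOpenAI
    plugins=[GoogleSearchTool(), WikipediaTool(), LLMTool(llm=llm)],
    agent_type=AgentType.react,  # mapped to LCAgentType.ZERO_SHOT_REACT_DESCRIPTION
)
response = agent("Tell me about Cinnamon AI company")  # a Document, per _run_tool above
```

Note that for `AgentType.self_ask` the wrapper keeps only the `search_doc` tool (renamed to "Intermediate Answer", as Langchain's self-ask agent expects) and skips executor initialization entirely if that tool is absent.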
@@ -9,7 +9,7 @@ from kotaemon.base.schema import Document
 from kotaemon.llms import LLM, ChatLLM, PromptTemplate
 from kotaemon.pipelines.citation import CitationPipeline
 
-from ..base import AgentOutput, AgentType, BaseAgent, BaseLLM, BaseTool
+from ..base import AgentType, BaseAgent, BaseLLM, BaseTool
 from ..output.base import BaseScratchPad
 from ..utils import get_plugin_response_content
 from .planner import Planner
@@ -28,7 +28,9 @@ class RewooAgent(BaseAgent):
         str, PromptTemplate
     ] = dict()  # {"Planner": xxx, "Solver": xxx}
     plugins: List[BaseTool] = list()
-    examples: Dict[str, Union[str, List[str]]] = dict()
+    examples: Dict[
+        str, Union[str, List[str]]
+    ] = dict()  # {"Planner": xxx, "Solver": xxx}
     args_schema: Optional[Type[BaseModel]] = create_model(
         "RewooArgsSchema", instruction=(str, ...)
     )
@@ -156,10 +158,6 @@ class RewooAgent(BaseAgent):
             if selected_plugin is None:
                 raise ValueError("Invalid plugin detected")
             tool_response = selected_plugin(tool_input)
-            # cumulate agent-as-plugin costs and tokens.
-            if isinstance(tool_response, AgentOutput):
-                result["plugin_cost"] = tool_response.cost
-                result["plugin_token"] = tool_response.token_usage
             result["evidence"] = get_plugin_response_content(tool_response)
         except ValueError:
             result["evidence"] = "No evidence found."
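The widened `examples` field mirrors `prompt_template`: one entry per role, as the inline `# {"Planner": xxx, "Solver": xxx}` comments indicate. A sketch of the expected shape only (the strings are placeholders):

```python
examples = {
    "Planner": "a single example as one string",
    "Solver": ["or", "a", "list", "of", "strings"],
}
```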
@@ -73,6 +73,7 @@ class Planner(BaseComponent):
         output.debug(f"Prompt: {prompt}")
         try:
             response = self.model(prompt)
+            self.log_progress(".planner", response=response)
             output.info("Planner run successful.")
         except ValueError as e:
             output.error("Planner failed to retrieve response from LLM")
@@ -1,12 +1,12 @@
-from .base import AgentOutput
+from ...base import Document
 
 
 def get_plugin_response_content(output) -> str:
     """
     Wrapper for AgentOutput content return
     """
-    if isinstance(output, AgentOutput):
-        return output.output
+    if isinstance(output, Document):
+        return output.text
     else:
         return str(output)
 
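With `AgentOutput` gone, the helper unwraps the shared `Document` schema instead. A minimal sketch of the new behaviour (the module path is inferred from the relative imports above; `Document(text=..., metadata=...)` is the constructor used in `langchain.py`):

```python
from kotaemon.base.schema import Document
from kotaemon.pipelines.agents.utils import get_plugin_response_content

doc = Document(text="some evidence", metadata={"agent": "langchain"})
assert get_plugin_response_content(doc) == "some evidence"  # Document -> .text
assert get_plugin_response_content(42) == "42"              # anything else -> str()
```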
@@ -43,14 +43,14 @@ class ReaderIndexingPipeline(BaseComponent):
     reader_name: str = "normal"  # "normal", "mathpix" or "ocr"
     chunk_size: int = 1024
     chunk_overlap: int = 256
-    vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
-    doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    vector_store: BaseVectorStore = _(InMemoryVectorStore)
+    doc_store: BaseDocumentStore = _(InMemoryDocumentStore)
     doc_parsers: List[DocParser] = []
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
         chunk_size=16,
     )
@@ -49,14 +49,14 @@ class QuestionAnsweringPipeline(BaseComponent):
         request_timeout=60,
     )
 
-    vector_store: _[BaseVectorStore] = _(InMemoryVectorStore)
-    doc_store: _[BaseDocumentStore] = _(InMemoryDocumentStore)
+    vector_store: BaseVectorStore = _(InMemoryVectorStore)
+    doc_store: BaseDocumentStore = _(InMemoryDocumentStore)
     rerankers: Sequence[BaseRerankingPipeline] = []
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        azure_endpoint="https://bleh-dummy.openai.azure.com/",
         openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
     )
 
@@ -137,8 +137,9 @@ class AgentQAPipeline(QuestionAnsweringPipeline):
             component=self.retrieving_pipeline,
         )
         if search_tool not in self.agent.plugins:
-            self.agent.plugins.append(search_tool)
+            self.agent.add_tools([search_tool])
 
     def run(self, question: str, use_citation: bool = False) -> Document:
-        answer = self.agent(question, use_citation=use_citation)
+        kwargs = {"use_citation": use_citation} if use_citation else {}
+        answer = self.agent(question, **kwargs)
         return answer
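The switch from `plugins.append` to `add_tools` is not cosmetic: for the new wrapper, adding a tool must also rebuild the underlying executor, and forwarding `use_citation` only when it is set presumably keeps agents whose call signature does not accept that keyword working. A sketch of the difference (`agent` stands for any `LangchainAgent`; `WikipediaTool` stands in for the search tool built above):

```python
new_tool = WikipediaTool()

agent.plugins.append(new_tool)  # old way: the Langchain AgentExecutor is never rebuilt
agent.add_tools([new_tool])     # new way: extends plugins AND re-runs
                                # update_agent_tools(), reinitializing the executor
```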
@@ -4,6 +4,8 @@ import pytest
 from openai.types.chat.chat_completion import ChatCompletion
 
 from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.pipelines.agents.base import AgentType
+from kotaemon.pipelines.agents.langchain import LangchainAgent
 from kotaemon.pipelines.agents.react import ReactAgent
 from kotaemon.pipelines.agents.rewoo import RewooAgent
 from kotaemon.pipelines.tools import (
@@ -13,7 +15,7 @@ from kotaemon.pipelines.tools import (
     WikipediaTool,
 )
 
-FINAL_RESPONSE_TEXT = "Hello Cinnamon AI!"
+FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
 
 
 _openai_chat_completion_responses_rewoo = [
@@ -199,7 +201,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     agent = initialize_agent(
         langchain_plugins,
         llm.agent,
-        agent=AgentType.OPENAI_FUNCTIONS,
+        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
     )
     response = agent("Tell me about Cinnamon AI company")
@@ -207,6 +209,26 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     assert response
 
 
+@patch(
+    "openai.resources.chat.completions.Completions.create",
+    side_effect=_openai_chat_completion_responses_react,
+)
+def test_wrapper_agent_langchain(openai_completion, llm, mock_google_search):
+    plugins = [
+        GoogleSearchTool(),
+        WikipediaTool(),
+        LLMTool(llm=llm),
+    ]
+    agent = LangchainAgent(
+        llm=llm,
+        plugins=plugins,
+        agent_type=AgentType.react,
+    )
+    response = agent("Tell me about Cinnamon AI company")
+    openai_completion.assert_called()
+    assert response
+
+
 @patch(
     "openai.resources.chat.completions.Completions.create",
     side_effect=_openai_chat_completion_responses_react_langchain_tool,