Create Langchain LLM converter to quickly supply it to Langchain's chain (#102)

* Create Langchain LLM converter to quickly supply it to Langchain's chain

* Clean up
This commit is contained in:
Duc Nguyen (john) 2023-12-11 14:55:56 +07:00 committed by GitHub
parent da0ac1d69f
commit 0e30dcbb06
10 changed files with 24 additions and 15 deletions

View File

@ -53,9 +53,7 @@ class LangchainAgent(BaseAgent):
# reinit Langchain AgentExecutor # reinit Langchain AgentExecutor
self.agent = initialize_agent( self.agent = initialize_agent(
langchain_plugins, langchain_plugins,
# TODO: could cause bugs for non-langchain llms self.llm.to_langchain_format(),
# related to https://github.com/Cinnamon/kotaemon/issues/73
self.llm._obj, # type: ignore
agent=self.AGENT_TYPE_MAP[self.agent_type], agent=self.AGENT_TYPE_MAP[self.agent_type],
handle_parsing_errors=True, handle_parsing_errors=True,
verbose=True, verbose=True,

View File

@ -1,16 +1,12 @@
from typing import Union
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM from .chats import AzureChatOpenAI, ChatLLM
from .completions import LLM, AzureOpenAI, OpenAI from .completions import LLM, AzureOpenAI, OpenAI
from .linear import GatedLinearPipeline, SimpleLinearPipeline from .linear import GatedLinearPipeline, SimpleLinearPipeline
from .prompts import BasePromptComponent, PromptTemplate from .prompts import BasePromptComponent, PromptTemplate
BaseLLM = Union[ChatLLM, LLM]
__all__ = [ __all__ = [
"BaseLLM", "BaseLLM",
# chat-specific components # chat-specific components

View File

@ -0,0 +1,8 @@
from langchain_core.language_models.base import BaseLanguageModel
from kotaemon.base import BaseComponent
class BaseLLM(BaseComponent):
def to_langchain_format(self) -> BaseLanguageModel:
raise NotImplementedError

View File

@ -3,11 +3,12 @@ from __future__ import annotations
import logging import logging
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.llms.base import BaseLLM
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ChatLLM(BaseComponent): class ChatLLM(BaseLLM):
def flow(self): def flow(self):
if self.inflow is None: if self.inflow is None:
raise ValueError("No inflow provided.") raise ValueError("No inflow provided.")

View File

@ -68,6 +68,9 @@ class LCChatMixin:
logits=[], logits=[],
) )
def to_langchain_format(self):
return self._obj
def __repr__(self): def __repr__(self):
kwargs = [] kwargs = []
for key, value_obj in self._kwargs.items(): for key, value_obj in self._kwargs.items():

View File

@ -1,5 +1,5 @@
from kotaemon.base import BaseComponent from kotaemon.llms.base import BaseLLM
class LLM(BaseComponent): class LLM(BaseLLM):
pass pass

View File

@ -45,6 +45,9 @@ class LCCompletionMixin:
logits=[], logits=[],
) )
def to_langchain_format(self):
return self._obj
def __repr__(self): def __repr__(self):
kwargs = [] kwargs = []
for key, value_obj in self._kwargs.items(): for key, value_obj in self._kwargs.items():

View File

@ -189,7 +189,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
langchain_plugins = [tool.to_langchain_format() for tool in plugins] langchain_plugins = [tool.to_langchain_format() for tool in plugins]
agent = initialize_agent( agent = initialize_agent(
langchain_plugins, langchain_plugins,
llm._obj, llm.to_langchain_format(),
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True, verbose=True,
) )

View File

@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion):
temperature=0, temperature=0,
) )
assert isinstance( assert isinstance(
model._obj, AzureChatOpenAILC model.to_langchain_format(), AzureChatOpenAILC
), "Agent not wrapped in Langchain's AzureChatOpenAI" ), "Agent not wrapped in Langchain's AzureChatOpenAI"
# test for str input - stream mode # test for str input - stream mode

View File

@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion):
request_timeout=60, request_timeout=60,
) )
assert isinstance( assert isinstance(
model._obj, AzureOpenAILC model.to_langchain_format(), AzureOpenAILC
), "Agent not wrapped in Langchain's AzureOpenAI" ), "Agent not wrapped in Langchain's AzureOpenAI"
output = model("hello world") output = model("hello world")
@ -64,7 +64,7 @@ def test_openai_model(openai_completion):
request_timeout=60, request_timeout=60,
) )
assert isinstance( assert isinstance(
model._obj, OpenAILC model.to_langchain_format(), OpenAILC
), "Agent is not wrapped in Langchain's OpenAI" ), "Agent is not wrapped in Langchain's OpenAI"
output = model("hello world") output = model("hello world")