Create Langchain LLM converter to quickly supply it to Langchain's chain (#102)

* Create Langchain LLM converter to quickly supply it to Langchain's chain

* Clean up
This commit is contained in:
Duc Nguyen (john) 2023-12-11 14:55:56 +07:00 committed by GitHub
parent da0ac1d69f
commit 0e30dcbb06
10 changed files with 24 additions and 15 deletions

View File

@@ -53,9 +53,7 @@ class LangchainAgent(BaseAgent):
# reinit Langchain AgentExecutor
self.agent = initialize_agent(
langchain_plugins,
# TODO: could cause bugs for non-langchain llms
# related to https://github.com/Cinnamon/kotaemon/issues/73
self.llm._obj, # type: ignore
self.llm.to_langchain_format(),
agent=self.AGENT_TYPE_MAP[self.agent_type],
handle_parsing_errors=True,
verbose=True,

View File

@@ -1,16 +1,12 @@
from typing import Union
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM
from .completions import LLM, AzureOpenAI, OpenAI
from .linear import GatedLinearPipeline, SimpleLinearPipeline
from .prompts import BasePromptComponent, PromptTemplate
BaseLLM = Union[ChatLLM, LLM]
__all__ = [
"BaseLLM",
# chat-specific components

View File

@@ -0,0 +1,8 @@
from langchain_core.language_models.base import BaseLanguageModel
from kotaemon.base import BaseComponent
class BaseLLM(BaseComponent):
def to_langchain_format(self) -> BaseLanguageModel:
raise NotImplementedError

View File

@@ -3,11 +3,12 @@ from __future__ import annotations
import logging
from kotaemon.base import BaseComponent
from kotaemon.llms.base import BaseLLM
logger = logging.getLogger(__name__)
class ChatLLM(BaseComponent):
class ChatLLM(BaseLLM):
def flow(self):
if self.inflow is None:
raise ValueError("No inflow provided.")

View File

@@ -68,6 +68,9 @@ class LCChatMixin:
logits=[],
)
def to_langchain_format(self):
return self._obj
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():

View File

@@ -1,5 +1,5 @@
from kotaemon.base import BaseComponent
from kotaemon.llms.base import BaseLLM
class LLM(BaseComponent):
class LLM(BaseLLM):
pass

View File

@@ -45,6 +45,9 @@ class LCCompletionMixin:
logits=[],
)
def to_langchain_format(self):
return self._obj
def __repr__(self):
kwargs = []
for key, value_obj in self._kwargs.items():

View File

@@ -189,7 +189,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
langchain_plugins = [tool.to_langchain_format() for tool in plugins]
agent = initialize_agent(
langchain_plugins,
llm._obj,
llm.to_langchain_format(),
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
)

View File

@@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion):
temperature=0,
)
assert isinstance(
model._obj, AzureChatOpenAILC
model.to_langchain_format(), AzureChatOpenAILC
), "Agent not wrapped in Langchain's AzureChatOpenAI"
# test for str input - stream mode

View File

@@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion):
request_timeout=60,
)
assert isinstance(
model._obj, AzureOpenAILC
model.to_langchain_format(), AzureOpenAILC
), "Agent not wrapped in Langchain's AzureOpenAI"
output = model("hello world")
@@ -64,7 +64,7 @@ def test_openai_model(openai_completion):
request_timeout=60,
)
assert isinstance(
model._obj, OpenAILC
model.to_langchain_format(), OpenAILC
), "Agent is not wrapped in Langchain's OpenAI"
output = model("hello world")