From 0e30dcbb061d105078fb0ae352456b8682ad0d92 Mon Sep 17 00:00:00 2001 From: "Duc Nguyen (john)" Date: Mon, 11 Dec 2023 14:55:56 +0700 Subject: [PATCH] Create Langchain LLM converter to quickly supply it to Langchain's chain (#102) * Create Langchain LLM converter to quickly supply it to Langchain's chain * Clean up --- knowledgehub/agents/langchain_based.py | 4 +--- knowledgehub/llms/__init__.py | 6 +----- knowledgehub/llms/base.py | 8 ++++++++ knowledgehub/llms/chats/base.py | 3 ++- knowledgehub/llms/chats/langchain_based.py | 3 +++ knowledgehub/llms/completions/base.py | 4 ++-- knowledgehub/llms/completions/langchain_based.py | 3 +++ tests/test_agent.py | 2 +- tests/test_llms_chat_models.py | 2 +- tests/test_llms_completion_models.py | 4 ++-- 10 files changed, 24 insertions(+), 15 deletions(-) create mode 100644 knowledgehub/llms/base.py diff --git a/knowledgehub/agents/langchain_based.py b/knowledgehub/agents/langchain_based.py index 8189b6f..9138a17 100644 --- a/knowledgehub/agents/langchain_based.py +++ b/knowledgehub/agents/langchain_based.py @@ -53,9 +53,7 @@ class LangchainAgent(BaseAgent): # reinit Langchain AgentExecutor self.agent = initialize_agent( langchain_plugins, - # TODO: could cause bugs for non-langchain llms - # related to https://github.com/Cinnamon/kotaemon/issues/73 - self.llm._obj, # type: ignore + self.llm.to_langchain_format(), agent=self.AGENT_TYPE_MAP[self.agent_type], handle_parsing_errors=True, verbose=True, diff --git a/knowledgehub/llms/__init__.py b/knowledgehub/llms/__init__.py index bdc61bc..521f596 100644 --- a/knowledgehub/llms/__init__.py +++ b/knowledgehub/llms/__init__.py @@ -1,16 +1,12 @@ -from typing import Union - from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage +from .base import BaseLLM from .branching import GatedBranchingPipeline, SimpleBranchingPipeline from .chats import AzureChatOpenAI, ChatLLM from .completions import LLM, AzureOpenAI, OpenAI from .linear import GatedLinearPipeline, SimpleLinearPipeline from .prompts import BasePromptComponent, PromptTemplate -BaseLLM = Union[ChatLLM, LLM] - - __all__ = [ "BaseLLM", # chat-specific components diff --git a/knowledgehub/llms/base.py b/knowledgehub/llms/base.py new file mode 100644 index 0000000..ff315ea --- /dev/null +++ b/knowledgehub/llms/base.py @@ -0,0 +1,8 @@ +from langchain_core.language_models.base import BaseLanguageModel + +from kotaemon.base import BaseComponent + + +class BaseLLM(BaseComponent): + def to_langchain_format(self) -> BaseLanguageModel: + raise NotImplementedError diff --git a/knowledgehub/llms/chats/base.py b/knowledgehub/llms/chats/base.py index a5280f4..09042bc 100644 --- a/knowledgehub/llms/chats/base.py +++ b/knowledgehub/llms/chats/base.py @@ -3,11 +3,12 @@ from __future__ import annotations import logging from kotaemon.base import BaseComponent +from kotaemon.llms.base import BaseLLM logger = logging.getLogger(__name__) -class ChatLLM(BaseComponent): +class ChatLLM(BaseLLM): def flow(self): if self.inflow is None: raise ValueError("No inflow provided.") diff --git a/knowledgehub/llms/chats/langchain_based.py b/knowledgehub/llms/chats/langchain_based.py index 1b937c8..7d4eb76 100644 --- a/knowledgehub/llms/chats/langchain_based.py +++ b/knowledgehub/llms/chats/langchain_based.py @@ -68,6 +68,9 @@ class LCChatMixin: logits=[], ) + def to_langchain_format(self): + return self._obj + def __repr__(self): kwargs = [] for key, value_obj in self._kwargs.items(): diff --git a/knowledgehub/llms/completions/base.py b/knowledgehub/llms/completions/base.py index e4a8540..004fddb 100644 --- a/knowledgehub/llms/completions/base.py +++ b/knowledgehub/llms/completions/base.py @@ -1,5 +1,5 @@ -from kotaemon.base import BaseComponent +from kotaemon.llms.base import BaseLLM -class LLM(BaseComponent): +class LLM(BaseLLM): pass diff --git a/knowledgehub/llms/completions/langchain_based.py b/knowledgehub/llms/completions/langchain_based.py index 97b2bda..0048ef6 100644 --- a/knowledgehub/llms/completions/langchain_based.py +++ b/knowledgehub/llms/completions/langchain_based.py @@ -45,6 +45,9 @@ class LCCompletionMixin: logits=[], ) + def to_langchain_format(self): + return self._obj + def __repr__(self): kwargs = [] for key, value_obj in self._kwargs.items(): diff --git a/tests/test_agent.py b/tests/test_agent.py index 4a10060..02e5da9 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -189,7 +189,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search): langchain_plugins = [tool.to_langchain_format() for tool in plugins] agent = initialize_agent( langchain_plugins, - llm._obj, + llm.to_langchain_format(), agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True, ) diff --git a/tests/test_llms_chat_models.py b/tests/test_llms_chat_models.py index 54c94ce..e7336d6 100644 --- a/tests/test_llms_chat_models.py +++ b/tests/test_llms_chat_models.py @@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion): temperature=0, ) assert isinstance( - model._obj, AzureChatOpenAILC + model.to_langchain_format(), AzureChatOpenAILC ), "Agent not wrapped in Langchain's AzureChatOpenAI" # test for str input - stream mode diff --git a/tests/test_llms_completion_models.py b/tests/test_llms_completion_models.py index e084935..ea56782 100644 --- a/tests/test_llms_completion_models.py +++ b/tests/test_llms_completion_models.py @@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion): request_timeout=60, ) assert isinstance( - model._obj, AzureOpenAILC + model.to_langchain_format(), AzureOpenAILC ), "Agent not wrapped in Langchain's AzureOpenAI" output = model("hello world") @@ -64,7 +64,7 @@ def test_openai_model(openai_completion): request_timeout=60, ) assert isinstance( - model._obj, OpenAILC + model.to_langchain_format(), OpenAILC ), "Agent is not wrapped in Langchain's OpenAI" output = model("hello world")