Create Langchain LLM converter to quickly supply it to Langchain's chain (#102)
* Create Langchain LLM converter to quickly supply it to Langchain's chain * Clean up
This commit is contained in:
parent
da0ac1d69f
commit
0e30dcbb06
|
@ -53,9 +53,7 @@ class LangchainAgent(BaseAgent):
|
|||
# reinit Langchain AgentExecutor
|
||||
self.agent = initialize_agent(
|
||||
langchain_plugins,
|
||||
# TODO: could cause bugs for non-langchain llms
|
||||
# related to https://github.com/Cinnamon/kotaemon/issues/73
|
||||
self.llm._obj, # type: ignore
|
||||
self.llm.to_langchain_format(),
|
||||
agent=self.AGENT_TYPE_MAP[self.agent_type],
|
||||
handle_parsing_errors=True,
|
||||
verbose=True,
|
||||
|
|
|
@ -1,16 +1,12 @@
|
|||
from typing import Union
|
||||
|
||||
from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from .base import BaseLLM
|
||||
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
|
||||
from .chats import AzureChatOpenAI, ChatLLM
|
||||
from .completions import LLM, AzureOpenAI, OpenAI
|
||||
from .linear import GatedLinearPipeline, SimpleLinearPipeline
|
||||
from .prompts import BasePromptComponent, PromptTemplate
|
||||
|
||||
BaseLLM = Union[ChatLLM, LLM]
|
||||
|
||||
|
||||
__all__ = [
|
||||
"BaseLLM",
|
||||
# chat-specific components
|
||||
|
|
8
knowledgehub/llms/base.py
Normal file
8
knowledgehub/llms/base.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
from langchain_core.language_models.base import BaseLanguageModel
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
|
||||
|
||||
class BaseLLM(BaseComponent):
|
||||
def to_langchain_format(self) -> BaseLanguageModel:
|
||||
raise NotImplementedError
|
|
@ -3,11 +3,12 @@ from __future__ import annotations
|
|||
import logging
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms.base import BaseLLM
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChatLLM(BaseComponent):
|
||||
class ChatLLM(BaseLLM):
|
||||
def flow(self):
|
||||
if self.inflow is None:
|
||||
raise ValueError("No inflow provided.")
|
||||
|
|
|
@ -68,6 +68,9 @@ class LCChatMixin:
|
|||
logits=[],
|
||||
)
|
||||
|
||||
def to_langchain_format(self):
|
||||
return self._obj
|
||||
|
||||
def __repr__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from kotaemon.base import BaseComponent
|
||||
from kotaemon.llms.base import BaseLLM
|
||||
|
||||
|
||||
class LLM(BaseComponent):
|
||||
class LLM(BaseLLM):
|
||||
pass
|
||||
|
|
|
@ -45,6 +45,9 @@ class LCCompletionMixin:
|
|||
logits=[],
|
||||
)
|
||||
|
||||
def to_langchain_format(self):
|
||||
return self._obj
|
||||
|
||||
def __repr__(self):
|
||||
kwargs = []
|
||||
for key, value_obj in self._kwargs.items():
|
||||
|
|
|
@ -189,7 +189,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
|
|||
langchain_plugins = [tool.to_langchain_format() for tool in plugins]
|
||||
agent = initialize_agent(
|
||||
langchain_plugins,
|
||||
llm._obj,
|
||||
llm.to_langchain_format(),
|
||||
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
|
||||
verbose=True,
|
||||
)
|
||||
|
|
|
@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion):
|
|||
temperature=0,
|
||||
)
|
||||
assert isinstance(
|
||||
model._obj, AzureChatOpenAILC
|
||||
model.to_langchain_format(), AzureChatOpenAILC
|
||||
), "Agent not wrapped in Langchain's AzureChatOpenAI"
|
||||
|
||||
# test for str input - stream mode
|
||||
|
|
|
@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion):
|
|||
request_timeout=60,
|
||||
)
|
||||
assert isinstance(
|
||||
model._obj, AzureOpenAILC
|
||||
model.to_langchain_format(), AzureOpenAILC
|
||||
), "Agent not wrapped in Langchain's AzureOpenAI"
|
||||
|
||||
output = model("hello world")
|
||||
|
@ -64,7 +64,7 @@ def test_openai_model(openai_completion):
|
|||
request_timeout=60,
|
||||
)
|
||||
assert isinstance(
|
||||
model._obj, OpenAILC
|
||||
model.to_langchain_format(), OpenAILC
|
||||
), "Agent is not wrapped in Langchain's OpenAI"
|
||||
|
||||
output = model("hello world")
|
||||
|
|
Loading…
Reference in New Issue
Block a user