Create Langchain LLM converter to quickly supply it to Langchain's chain (#102)

* Create Langchain LLM converter to quickly supply it to Langchain's chain
* Clean up

parent da0ac1d69f
commit 0e30dcbb06
@@ -53,9 +53,7 @@ class LangchainAgent(BaseAgent):
         # reinit Langchain AgentExecutor
         self.agent = initialize_agent(
             langchain_plugins,
-            # TODO: could cause bugs for non-langchain llms
-            # related to https://github.com/Cinnamon/kotaemon/issues/73
-            self.llm._obj,  # type: ignore
+            self.llm.to_langchain_format(),
             agent=self.AGENT_TYPE_MAP[self.agent_type],
             handle_parsing_errors=True,
             verbose=True,
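The change above replaces direct access to the private `_obj` attribute with the new converter. A minimal sketch of how a caller now supplies a kotaemon LLM to LangChain's agent machinery (all keyword values are placeholders, not values from this commit):

    from langchain.agents import AgentType, initialize_agent

    from kotaemon.llms.chats import AzureChatOpenAI

    llm = AzureChatOpenAI(
        openai_api_base="https://example.openai.azure.com/",
        openai_api_key="PLACEHOLDER",
        openai_api_version="2023-03-15-preview",
        deployment_name="gpt35turbo",
        temperature=0,
    )

    # to_langchain_format() hands back the wrapped native LangChain model,
    # so callers no longer reach into the private `_obj` attribute.
    agent = initialize_agent(
        langchain_tools,  # tools already converted via tool.to_langchain_format()
        llm.to_langchain_format(),
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )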
@@ -1,16 +1,12 @@
-from typing import Union
-
 from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMessage
 
+from .base import BaseLLM
 from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
 from .chats import AzureChatOpenAI, ChatLLM
 from .completions import LLM, AzureOpenAI, OpenAI
 from .linear import GatedLinearPipeline, SimpleLinearPipeline
 from .prompts import BasePromptComponent, PromptTemplate
 
-BaseLLM = Union[ChatLLM, LLM]
-
-
 __all__ = [
     "BaseLLM",
     # chat-specific components
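One practical effect of replacing the `Union[ChatLLM, LLM]` alias with a real base class (a sketch, assuming the import layout above):

    from kotaemon.llms import LLM, BaseLLM, ChatLLM

    # BaseLLM is now a concrete class, so subclass checks work directly;
    # issubclass() against the old typing.Union alias would raise TypeError.
    assert issubclass(LLM, BaseLLM)
    assert issubclass(ChatLLM, BaseLLM)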
knowledgehub/llms/base.py (new file, 8 lines)

@@ -0,0 +1,8 @@
+from langchain_core.language_models.base import BaseLanguageModel
+
+from kotaemon.base import BaseComponent
+
+
+class BaseLLM(BaseComponent):
+    def to_langchain_format(self) -> BaseLanguageModel:
+        raise NotImplementedError
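Concrete wrappers are expected to override this hook. A hypothetical subclass, illustrative only (it assumes plain-attribute storage works on BaseComponent, which this diff does not show):

    from langchain_core.language_models.base import BaseLanguageModel

    from kotaemon.llms.base import BaseLLM


    class WrappedLLM(BaseLLM):
        """Hypothetical wrapper around a pre-built native LangChain model."""

        def __init__(self, obj: BaseLanguageModel):
            self._obj = obj

        def to_langchain_format(self) -> BaseLanguageModel:
            # Expose the wrapped native object to LangChain callers.
            return self._obj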
@@ -3,11 +3,12 @@ from __future__ import annotations
 import logging
 
 from kotaemon.base import BaseComponent
+from kotaemon.llms.base import BaseLLM
 
 logger = logging.getLogger(__name__)
 
 
-class ChatLLM(BaseComponent):
+class ChatLLM(BaseLLM):
     def flow(self):
         if self.inflow is None:
             raise ValueError("No inflow provided.")
@@ -68,6 +68,9 @@ class LCChatMixin:
             logits=[],
         )
 
+    def to_langchain_format(self):
+        return self._obj
+
     def __repr__(self):
         kwargs = []
         for key, value_obj in self._kwargs.items():
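With the converter on the chat mixin, a wrapped chat model drops straight into a LangChain chain. A hedged sketch using the LLMChain API of this LangChain era (keyword values are placeholders):

    from langchain.chains import LLMChain
    from langchain.prompts import PromptTemplate

    from kotaemon.llms.chats import AzureChatOpenAI

    llm = AzureChatOpenAI(
        openai_api_base="https://example.openai.azure.com/",
        openai_api_key="PLACEHOLDER",
        openai_api_version="2023-03-15-preview",
        deployment_name="gpt35turbo",
    )

    chain = LLMChain(
        llm=llm.to_langchain_format(),  # the native LangChain chat model
        prompt=PromptTemplate.from_template("Summarize: {text}"),
    )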
@@ -1,5 +1,5 @@
-from kotaemon.base import BaseComponent
+from kotaemon.llms.base import BaseLLM
 
 
-class LLM(BaseComponent):
+class LLM(BaseLLM):
     pass
@@ -45,6 +45,9 @@ class LCCompletionMixin:
             logits=[],
         )
 
+    def to_langchain_format(self):
+        return self._obj
+
     def __repr__(self):
         kwargs = []
         for key, value_obj in self._kwargs.items():
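The completion mixin gains the identical hook, so completion-style models convert the same way (sketch; the keyword arguments mirror the test hunks below and are placeholders):

    from kotaemon.llms.completions import AzureOpenAI

    model = AzureOpenAI(
        openai_api_base="https://example.openai.azure.com/",
        openai_api_key="PLACEHOLDER",
        openai_api_version="2023-03-15-preview",
        deployment_name="gpt35turbo",
        request_timeout=60,
    )

    native = model.to_langchain_format()  # the underlying LangChain AzureOpenAI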
@@ -189,7 +189,7 @@ def test_react_agent_langchain(openai_completion, llm, mock_google_search):
     langchain_plugins = [tool.to_langchain_format() for tool in plugins]
     agent = initialize_agent(
         langchain_plugins,
-        llm._obj,
+        llm.to_langchain_format(),
         agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
         verbose=True,
     )
@@ -48,7 +48,7 @@ def test_azureopenai_model(openai_completion):
         temperature=0,
     )
     assert isinstance(
-        model._obj, AzureChatOpenAILC
+        model.to_langchain_format(), AzureChatOpenAILC
     ), "Agent not wrapped in Langchain's AzureChatOpenAI"
 
     # test for str input - stream mode
@@ -41,7 +41,7 @@ def test_azureopenai_model(openai_completion):
         request_timeout=60,
     )
     assert isinstance(
-        model._obj, AzureOpenAILC
+        model.to_langchain_format(), AzureOpenAILC
     ), "Agent not wrapped in Langchain's AzureOpenAI"
 
     output = model("hello world")
@@ -64,7 +64,7 @@ def test_openai_model(openai_completion):
         request_timeout=60,
     )
     assert isinstance(
-        model._obj, OpenAILC
+        model.to_langchain_format(), OpenAILC
     ), "Agent is not wrapped in Langchain's OpenAI"
 
     output = model("hello world")
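A complementary check not present in this commit (hypothetical): a BaseLLM subclass that never overrides the hook should surface NotImplementedError rather than silently hand LangChain a broken object.

    import pytest

    from kotaemon.llms.base import BaseLLM


    class _NoConverter(BaseLLM):
        pass


    def test_to_langchain_format_unimplemented():
        # The base-class stub raises; this assumes _NoConverter() can be
        # constructed without arguments, which the diff does not show.
        with pytest.raises(NotImplementedError):
            _NoConverter().to_langchain_format()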