Update Base interface of Index/Retrieval pipeline (#36)
* add base Tool * minor update test_tool * update test dependency * update test dependency * Fix namespace conflict * update test * add base Agent Interface, add ReWoo Agent * minor update * update test * fix typo * remove unneeded print * update rewoo agent * add LLMTool * update BaseAgent type * add ReAct agent * add ReAct agent * minor update * minor update * minor update * minor update * update base reader with BaseComponent * add splitter * update agent and tool * update vectorstores * update load/save for indexing and retrieving pipeline * update test_agent for more use-cases * add missing dependency for test * update test case for in memory vectorstore * add TextSplitter to BaseComponent * update type hint basetool --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
committed by
GitHub
parent
49ed3f6994
commit
56bc41b673
@@ -1,6 +1,7 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||
|
||||
from langchain.agents import Tool as LCTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
from kotaemon.base import BaseComponent
|
||||
@@ -87,6 +88,10 @@ class BaseTool(BaseComponent):
|
||||
)
|
||||
return observation
|
||||
|
||||
def to_langchain_format(self) -> LCTool:
|
||||
"""Convert this tool to Langchain format to use with its agent"""
|
||||
return LCTool(name=self.name, description=self.description, func=self.run)
|
||||
|
||||
def run_raw(
|
||||
self,
|
||||
tool_input: Union[str, Dict],
|
||||
@@ -122,6 +127,15 @@ class BaseTool(BaseComponent):
|
||||
"""Tool does not support processing batch"""
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
|
||||
"""Wrapper for Langchain Tool"""
|
||||
new_tool = BaseTool(
|
||||
name=langchain_tool.name, description=langchain_tool.description
|
||||
)
|
||||
new_tool._run_tool = langchain_tool._run # type: ignore
|
||||
return new_tool
|
||||
|
||||
|
||||
class ComponentTool(BaseTool):
|
||||
"""
|
||||
@@ -130,6 +144,11 @@ class ComponentTool(BaseTool):
|
||||
"""
|
||||
|
||||
component: BaseComponent
|
||||
postprocessor: Optional[Callable] = None
|
||||
|
||||
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.component(*args, **kwargs)
|
||||
output = self.component(*args, **kwargs)
|
||||
if self.postprocessor:
|
||||
output = self.postprocessor(output)
|
||||
|
||||
return output
|
||||
|
Reference in New Issue
Block a user