Update Base interface of Index/Retrieval pipeline (#36)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

* add base Agent Interface, add ReWoo Agent

* minor update

* update test

* fix typo

* remove unneeded print

* update rewoo agent

* add LLMTool

* update BaseAgent type

* add ReAct agent

* add ReAct agent

* minor update

* minor update

* minor update

* minor update

* update base reader with BaseComponent

* add splitter

* update agent and tool

* update vectorstores

* update load/save for indexing and retrieving pipeline

* update test_agent for more use-cases

* add missing dependency for test

* update test case for in memory vectorstore

* add TextSplitter to BaseComponent

* update type hint basetool

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin)
2023-10-04 14:27:44 +07:00
committed by GitHub
parent 49ed3f6994
commit 56bc41b673
13 changed files with 302 additions and 36 deletions

View File

@@ -1,6 +1,7 @@
from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from langchain.agents import Tool as LCTool
from pydantic import BaseModel
from kotaemon.base import BaseComponent
@@ -87,6 +88,10 @@ class BaseTool(BaseComponent):
)
return observation
def to_langchain_format(self) -> LCTool:
"""Convert this tool to Langchain format to use with its agent"""
return LCTool(name=self.name, description=self.description, func=self.run)
def run_raw(
self,
tool_input: Union[str, Dict],
@@ -122,6 +127,15 @@ class BaseTool(BaseComponent):
"""Tool does not support processing batch"""
return False
@classmethod
def from_langchain_format(cls, langchain_tool: LCTool) -> "BaseTool":
"""Wrapper for Langchain Tool"""
new_tool = BaseTool(
name=langchain_tool.name, description=langchain_tool.description
)
new_tool._run_tool = langchain_tool._run # type: ignore
return new_tool
class ComponentTool(BaseTool):
"""
@@ -130,6 +144,11 @@ class ComponentTool(BaseTool):
"""
component: BaseComponent
postprocessor: Optional[Callable] = None
def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
return self.component(*args, **kwargs)
output = self.component(*args, **kwargs)
if self.postprocessor:
output = self.postprocessor(output)
return output