[AUR-363, AUR-433, AUR-434] Add Base Tool interface with Wikipedia/Google tools (#30)
* add base Tool * minor update test_tool * update test dependency * update test dependency * Fix namespace conflict * update test --------- Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
parent
317323c0e5
commit
f9fc02a32a
5
knowledgehub/pipelines/tools/__init__.py
Normal file
5
knowledgehub/pipelines/tools/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
from .base import BaseTool, ComponentTool
|
||||||
|
from .google import GoogleSearchTool
|
||||||
|
from .wikipedia import WikipediaTool
|
||||||
|
|
||||||
|
# Names re-exported from the sibling modules imported above.
__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool"]
|
135
knowledgehub/pipelines/tools/base.py
Normal file
135
knowledgehub/pipelines/tools/base.py
Normal file
|
@ -0,0 +1,135 @@
|
||||||
|
from abc import abstractmethod
|
||||||
|
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
|
||||||
|
|
||||||
|
class ToolException(Exception):
    """Optional, recoverable error a tool may raise while it is executing.

    Raising it does not stop the agent: the error is processed according to
    the tool's ``handle_tool_error`` setting, and the outcome is handed back
    to the agent as the observation (and printed in red on the console).
    """
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTool(BaseComponent):
    # Shared tool interface: `run_raw` validates the input against the optional
    # pydantic `args_schema`, dispatches to `_run_tool`, and converts a raised
    # `ToolException` into an observation according to `handle_tool_error`.

    name: str
    """Unique tool name; it should make the tool's purpose obvious."""
    description: str
    """Text that tells the model how, when and why to use the tool.

    Few-shot examples may be included; this text is fed into the LLM prompt.
    """
    args_schema: Optional[Type[BaseModel]] = None
    """Pydantic model class used to validate and parse the tool's arguments."""
    verbose: bool = False
    """Whether the tool's progress should be logged."""
    handle_tool_error: Optional[
        Union[bool, str, Callable[[ToolException], str]]
    ] = False
    """How the content of a raised ToolException is handled."""

    def _parse_input(
        self,
        tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Validate ``tool_input`` against ``args_schema`` when one is set.

        A string is validated as the value of the schema's first field and
        returned unchanged. Any other input is parsed by the schema, and only
        the keys present in the caller's input are returned. Without a schema
        the input is passed through untouched.
        """
        schema = self.args_schema
        if schema is None:
            return tool_input

        if isinstance(tool_input, str):
            # The bare string stands for the first declared field.
            first_field = next(iter(schema.__fields__))
            schema.validate({first_field: tool_input})
            return tool_input

        parsed = schema.parse_obj(tool_input).dict()
        # Drop fields the caller never supplied (e.g. filled-in defaults).
        return {key: value for key, value in parsed.items() if key in tool_input}

    @abstractmethod
    def _run_tool(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute the tool; concrete tools provide the implementation."""

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Split ``tool_input`` into positional and keyword arguments.

        A string becomes the single positional argument (kept for backwards
        compatibility); anything else is used as the keyword arguments.
        """
        if not isinstance(tool_input, str):
            return (), tool_input
        return (tool_input,), {}

    def _handle_tool_error(self, e: ToolException) -> Any:
        """Turn a ``ToolException`` into an observation per ``handle_tool_error``.

        - falsy setting: the exception is re-raised;
        - ``True``: the exception's first argument, or a generic message when
          it has no arguments;
        - a string: that string;
        - a callable: the result of calling it with the exception.

        Raises:
            ToolException: when ``handle_tool_error`` is falsy.
            ValueError: when ``handle_tool_error`` is of any other type.
        """
        handler = self.handle_tool_error
        if not handler:
            raise e
        if isinstance(handler, bool):
            return e.args[0] if e.args else "Tool execution error"
        if isinstance(handler, str):
            return handler
        if callable(handler):
            return handler(e)
        raise ValueError(
            f"Got unexpected type of `handle_tool_error`. Expected bool, str "
            f"or callable. Received: {handler}"
        )

    def run_raw(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        """Validate ``tool_input``, run the tool and return its observation.

        Input validation happens before the ``try`` block, so only a
        ``ToolException`` raised while preparing the call or running the tool
        is routed through ``_handle_tool_error``. ``verbose`` and ``kwargs``
        are accepted but not used yet.
        """
        parsed_input = self._parse_input(tool_input)
        # TODO (verbose_): Add logging
        try:
            call_args, call_kwargs = self._to_args_and_kwargs(parsed_input)
            return self._run_tool(*call_args, **call_kwargs)
        except ToolException as err:
            return self._handle_tool_error(err)

    def run_document(self, *args, **kwargs):
        """No-op placeholder; returns ``None``."""
        return None

    def run_batch_raw(self, *args, **kwargs):
        """No-op placeholder; returns ``None``."""
        return None

    def run_batch_document(self, *args, **kwargs):
        """No-op placeholder; returns ``None``."""
        return None

    def is_document(self, *args, **kwargs) -> bool:
        """Tool does not support processing document"""
        return False

    def is_batch(self, *args, **kwargs) -> bool:
        """Tool does not support processing batch"""
        return False
|
||||||
|
|
||||||
|
|
||||||
|
class ComponentTool(BaseTool):
    """
    A Tool that wraps another pipeline / BaseComponent and uses it
    as its main entry point
    """

    component: BaseComponent

    def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
        # Forward the call, unchanged, to the wrapped component.
        wrapped = self.component
        return wrapped(*args, **kwargs)
|
35
knowledgehub/pipelines/tools/google.py
Normal file
35
knowledgehub/pipelines/tools/google.py
Normal file
|
@ -0,0 +1,35 @@
|
||||||
|
from typing import AnyStr, Optional, Type
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
# Argument schema used by GoogleSearchTool (assigned to its `args_schema`):
# a single string field, `query`, declared with `Field(...)`.
class GoogleSearchArgs(BaseModel):
    query: str = Field(..., description="a search query")
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleSearchTool(BaseTool):
    # Tool that runs a Google query through the optional `googlesearch`
    # package and returns the result snippets as plain text.

    name = "google_search"
    description = (
        "A search engine retrieving top search results as snippets from Google. "
        "Input should be a search query."
    )
    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs

    def _run_tool(self, query: AnyStr) -> str:
        """Search Google for ``query`` and return one line per result.

        Each line is the result's title and description separated by a space.
        An empty string is returned when the search yields nothing.

        Raises:
            ImportError: if the optional ``googlesearch`` package is missing.
        """
        try:
            from googlesearch import search
        except ImportError as err:
            # Chain the original error so the underlying import failure
            # stays visible in the traceback.
            raise ImportError(
                "install googlesearch using `pip3 install googlesearch-python` to "
                "use this tool"
            ) from err

        # No truthiness check on the results: `search` hands back a lazy
        # iterable (always truthy), and joining nothing already yields "".
        return "\n".join(
            f"{item.title} {item.description}"
            for item in search(query, advanced=True)
        )
|
65
knowledgehub/pipelines/tools/wikipedia.py
Normal file
65
knowledgehub/pipelines/tools/wikipedia.py
Normal file
|
@ -0,0 +1,65 @@
|
||||||
|
from typing import Any, AnyStr, Optional, Type, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from kotaemon.documents.base import Document
|
||||||
|
|
||||||
|
from .base import BaseTool
|
||||||
|
|
||||||
|
|
||||||
|
class Wiki:
    """Wrapper around wikipedia API."""

    def __init__(self) -> None:
        """Check that wikipedia package is installed.

        Raises:
            ValueError: if the ``wikipedia`` package cannot be imported.
        """
        try:
            import wikipedia  # noqa: F401
        except ImportError as err:
            # Same exception type as before; chained so the underlying
            # import failure stays visible in the traceback.
            raise ValueError(
                "Could not import wikipedia python package. "
                "Please install it with `pip install wikipedia`."
            ) from err

    def search(self, search: str) -> Union[str, Document]:
        """Try to search for wiki page.

        Returns:
            A ``Document`` whose text is the page content and whose metadata
            stores the page URL under ``"page"`` when the page is found.
            When the lookup raises ``PageError`` or ``DisambiguationError``,
            a message string listing similar entries instead.
        """
        import wikipedia

        try:
            # Resolve the page once and reuse it for both content and URL
            # instead of issuing the same lookup twice.
            page = wikipedia.page(search)
            return Document(text=page.content, metadata={"page": page.url})
        except (wikipedia.PageError, wikipedia.DisambiguationError):
            # Both failure modes produce the same fallback message.
            return f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
|
||||||
|
|
||||||
|
|
||||||
|
# Argument schema used by WikipediaTool (assigned to its `args_schema`):
# a single string field, `query`, declared with `Field(...)`.
class WikipediaArgs(BaseModel):
    query: str = Field(..., description="a search query as input to Wikipedia")
|
||||||
|
|
||||||
|
|
||||||
|
class WikipediaTool(BaseTool):
    """Tool that adds the capability to query the Wikipedia API."""

    name = "wikipedia"
    description = (
        "Search engine from Wikipedia, retrieving relevant wiki page. "
        "Useful when you need to get holistic knowledge about people, "
        "places, companies, historical events, or other subjects."
    )
    args_schema: Optional[Type[BaseModel]] = WikipediaArgs
    doc_store: Any = None

    def _run_tool(self, query: AnyStr) -> Union[str, Document]:
        """Look up ``query`` through the ``Wiki`` wrapper.

        Returns whatever ``Wiki.search`` returns: a ``Document`` for a found
        page, or a message string when the page could not be resolved.
        """
        # Build the Wiki wrapper on first use and keep it on the instance.
        if not self.doc_store:
            self.doc_store = Wiki()
        return self.doc_store.search(query)
|
2
setup.py
2
setup.py
|
@ -54,6 +54,8 @@ setuptools.setup(
|
||||||
# optional dependency needed for test
|
# optional dependency needed for test
|
||||||
"openai",
|
"openai",
|
||||||
"chromadb",
|
"chromadb",
|
||||||
|
"wikipedia",
|
||||||
|
"googlesearch-python",
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
|
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
|
||||||
|
|
70
tests/test_tools.py
Normal file
70
tests/test_tools.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai.api_resources.embedding import Embedding
|
||||||
|
|
||||||
|
from kotaemon.docstores import InMemoryDocumentStore
|
||||||
|
from kotaemon.documents.base import Document
|
||||||
|
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
|
||||||
|
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
|
||||||
|
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
|
||||||
|
from kotaemon.pipelines.tools import ComponentTool, GoogleSearchTool, WikipediaTool
|
||||||
|
from kotaemon.vectorstores import ChromaVectorStore
|
||||||
|
|
||||||
|
# Canned embedding response, loaded once at import time; the
# `mock_openai_embedding` fixture returns it in place of a real API call.
_embedding_resource = Path(__file__).parent / "resources" / "embedding_openai.json"
with _embedding_resource.open() as f:
    openai_embedding = json.load(f)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function")
def mock_openai_embedding(monkeypatch):
    """Patch ``Embedding.create`` to return the canned embedding response."""

    def _fake_create(*args, **kwargs):
        # Ignore all arguments and return the response loaded at import time.
        return openai_embedding

    monkeypatch.setattr(Embedding, "create", _fake_create)
|
||||||
|
|
||||||
|
|
||||||
|
def test_google_tool():
    """GoogleSearchTool has a name and description and returns a truthy result."""
    google_tool = GoogleSearchTool()
    assert google_tool.name
    assert google_tool.description

    result = google_tool("What is Cinnamon AI")
    assert result
|
||||||
|
|
||||||
|
|
||||||
|
def test_wikipedia_tool():
    """WikipediaTool has a name and description and returns a truthy result."""
    wiki_tool = WikipediaTool()
    assert wiki_tool.name
    assert wiki_tool.description

    result = wiki_tool("Cinnamon")
    assert result
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_tool(mock_openai_embedding, tmp_path):
    """Index a document through one ComponentTool, then query through another."""
    vector_store = ChromaVectorStore(path=str(tmp_path))
    document_store = InMemoryDocumentStore()
    embedder = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",
        deployment="embedding-deployment",
        openai_api_base="https://test.openai.azure.com/",
        openai_api_key="some-key",
    )

    # Both pipelines share the same vector store, document store and embedder.
    indexer = IndexVectorStoreFromDocumentPipeline(
        vector_store=vector_store, embedding=embedder, doc_store=document_store
    )
    retriever = RetrieveDocumentFromVectorStorePipeline(
        vector_store=vector_store, doc_store=document_store, embedding=embedder
    )

    index_tool = ComponentTool(
        name="index_document",
        description="A tool to use to index a document to be searched later",
        component=indexer,
    )
    # A dict input is passed to the wrapped pipeline as keyword arguments.
    index_tool({"text": Document(text="Cinnamon AI")})

    retrieval_tool = ComponentTool(
        name="search_document",
        description="A tool to use to search a document in a vectorstore",
        component=retriever,
    )
    # A string input is passed to the wrapped pipeline positionally.
    found = retrieval_tool("Cinnamon AI")
    assert found
|
Loading…
Reference in New Issue
Block a user