[AUR-363, AUR-433, AUR-434] Add Base Tool interface with Wikipedia/Google tools (#30)

* add base Tool

* minor update test_tool

* update test dependency

* update test dependency

* Fix namespace conflict

* update test

---------

Co-authored-by: trducng <trungduc1992@gmail.com>
This commit is contained in:
Tuan Anh Nguyen Dang (Tadashi_Cin) 2023-09-29 10:18:49 +07:00 committed by GitHub
parent 317323c0e5
commit f9fc02a32a
6 changed files with 312 additions and 0 deletions

View File

@ -0,0 +1,5 @@
from .base import BaseTool, ComponentTool
from .google import GoogleSearchTool
from .wikipedia import WikipediaTool
# Names re-exported as the public API of this package.
__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool"]

View File

@ -0,0 +1,135 @@
from abc import abstractmethod
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
from pydantic import BaseModel
from kotaemon.base import BaseComponent
class ToolException(Exception):
    """Exception a tool may raise to signal an execution error.

    `BaseTool.run_raw` catches this exception and hands it to
    `BaseTool._handle_tool_error`, which either re-raises it or converts it
    into an observation, depending on the tool's `handle_tool_error` setting.
    """
class BaseTool(BaseComponent):
    # Common interface for tools: subclasses implement `_run_tool`, while
    # `run_raw` takes care of input validation and ToolException handling.

    name: str
    """The unique name of the tool that clearly communicates its purpose."""

    description: str
    """Description used to tell the model how/when/why to use the tool.

    You can provide few-shot examples as a part of the description. This will be
    input to the prompt of LLM.
    """

    args_schema: Optional[Type[BaseModel]] = None
    """Pydantic model class to validate and parse the tool's input arguments."""

    verbose: bool = False
    """Whether to log the tool's progress."""

    handle_tool_error: Optional[
        Union[bool, str, Callable[[ToolException], str]]
    ] = False
    """Handle the content of the ToolException thrown."""

    def _parse_input(
        self,
        tool_input: Union[str, Dict],
    ) -> Union[str, Dict[str, Any]]:
        """Validate `tool_input` against `args_schema`, when one is set.

        A string input is validated as the value of the schema's first field
        and returned unchanged. A dict input is parsed by the schema and only
        the keys the caller supplied are returned. Without a schema the input
        is returned as-is.
        """
        schema = self.args_schema
        if schema is None:
            return tool_input

        if isinstance(tool_input, str):
            # A bare string is treated as the value of the first schema field.
            first_field = next(iter(schema.__fields__))
            schema.validate({first_field: tool_input})
            return tool_input

        parsed = schema.parse_obj(tool_input)
        # Drop fields the schema filled in that the caller never provided.
        return {
            key: value for key, value in parsed.dict().items() if key in tool_input
        }

    @abstractmethod
    def _run_tool(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """Execute the tool; concrete tools must override this."""

    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
        """Split parsed input into positional and keyword arguments."""
        # For backwards compatibility, if tool_input is a string,
        # pass it as a single positional argument.
        if isinstance(tool_input, str):
            return (tool_input,), {}
        return (), tool_input

    def _handle_tool_error(self, e: ToolException) -> Any:
        """Turn a ToolException into an observation, or re-raise it.

        Behaviour follows `self.handle_tool_error`: a falsy value re-raises
        `e`; `True` yields the exception's first argument (or a generic
        message); a string is returned verbatim; a callable is invoked with
        the exception.

        Raises:
            ToolException: when `handle_tool_error` is falsy.
            ValueError: when `handle_tool_error` has an unsupported type.
        """
        handler = self.handle_tool_error
        if not handler:
            raise e
        if isinstance(handler, bool):
            return e.args[0] if e.args else "Tool execution error"
        if isinstance(handler, str):
            return handler
        if callable(handler):
            return handler(e)
        raise ValueError(
            f"Got unexpected type of `handle_tool_error`. Expected bool, str "
            f"or callable. Received: {self.handle_tool_error}"
        )

    def run_raw(
        self,
        tool_input: Union[str, Dict],
        verbose: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        """Validate the input, run the tool and return its observation.

        `verbose` and `**kwargs` are accepted but not used here. Input
        validation happens outside the try block, so schema errors are not
        routed through `_handle_tool_error`.
        """
        cleaned_input = self._parse_input(tool_input)
        # TODO (verbose_): Add logging
        try:
            positional, keyword = self._to_args_and_kwargs(cleaned_input)
            return self._run_tool(*positional, **keyword)
        except ToolException as error:
            return self._handle_tool_error(error)

    def run_document(self, *args, **kwargs):
        """No-op: tools do not process documents."""
        return None

    def run_batch_raw(self, *args, **kwargs):
        """No-op: tools do not process batches."""
        return None

    def run_batch_document(self, *args, **kwargs):
        """No-op: tools do not process batches of documents."""
        return None

    def is_document(self, *args, **kwargs) -> bool:
        """Tool does not support processing document"""
        return False

    def is_batch(self, *args, **kwargs) -> bool:
        """Tool does not support processing batch"""
        return False
class ComponentTool(BaseTool):
    """
    A Tool based on another pipeline / BaseComponent to be used
    as its main entry point
    """

    # The wrapped component that actually does the work.
    component: BaseComponent

    def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
        """Forward all arguments to the wrapped component and return its result."""
        wrapped = self.component
        return wrapped(*args, **kwargs)

View File

@ -0,0 +1,35 @@
from typing import AnyStr, Optional, Type
from pydantic import BaseModel, Field
from .base import BaseTool
# Input schema used by GoogleSearchTool: a single `query` string field.
class GoogleSearchArgs(BaseModel):
    query: str = Field(..., description="a search query")
class GoogleSearchTool(BaseTool):
    # Tool that joins the title and description of each search result
    # returned by the `googlesearch` package into one text blob.
    name = "google_search"
    description = (
        "A search engine retrieving top search results as snippets from Google. "
        "Input should be a search query."
    )
    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs

    def _run_tool(self, query: AnyStr) -> str:
        """Search for `query` and return one "title description" line per hit.

        Returns an empty string when the search yields a falsy result.

        Raises:
            ImportError: if the `googlesearch` package is not installed.
        """
        # Imported lazily so the dependency is only needed when the tool runs.
        try:
            from googlesearch import search
        except ImportError:
            raise ImportError(
                "install googlesearch using `pip3 install googlesearch-python` to "
                "use this tool"
            )

        hits = search(query, advanced=True)
        if not hits:
            return ""

        snippets = [f"{hit.title} {hit.description}" for hit in hits]
        return "\n".join(snippets)

View File

@ -0,0 +1,65 @@
from typing import Any, AnyStr, Optional, Type, Union
from pydantic import BaseModel, Field
from kotaemon.documents.base import Document
from .base import BaseTool
class Wiki:
    """Wrapper around wikipedia API."""

    def __init__(self) -> None:
        """Check that wikipedia package is installed.

        Raises:
            ValueError: if the `wikipedia` package cannot be imported.
        """
        try:
            import wikipedia  # noqa: F401
        except ImportError as err:
            # Keep ValueError for existing callers; chain the original cause.
            raise ValueError(
                "Could not import wikipedia python package. "
                "Please install it with `pip install wikipedia`."
            ) from err

    def search(self, search: str) -> Union[str, Document]:
        """Try to fetch the wiki page matching `search`.

        Args:
            search: page title or query to look up.

        Returns:
            A `Document` holding the page content, with the page URL stored
            under the "page" metadata key, when the page is found. Otherwise
            a string listing similar entries returned by `wikipedia.search`.
        """
        import wikipedia

        try:
            # Resolve the page once and reuse it for both content and URL,
            # instead of issuing the same lookup twice.
            page = wikipedia.page(search)
            result: Union[str, Document] = Document(
                text=page.content, metadata={"page": page.url}
            )
        except (wikipedia.PageError, wikipedia.DisambiguationError):
            # Missing and ambiguous titles get the same fallback message.
            result = f"Could not find [{search}]. Similar: {wikipedia.search(search)}"
        return result
# Input schema used by WikipediaTool: a single `query` string field.
class WikipediaArgs(BaseModel):
    # Description text fixed: "wkipedia" -> "wikipedia".
    query: str = Field(..., description="a search query as input to wikipedia")
class WikipediaTool(BaseTool):
    """Tool that adds the capability to query the Wikipedia API."""

    name = "wikipedia"
    description = (
        "Search engine from Wikipedia, retrieving relevant wiki page. "
        "Useful when you need to get holistic knowledge about people, "
        "places, companies, historical events, or other subjects."
    )
    args_schema: Optional[Type[BaseModel]] = WikipediaArgs
    # Wikipedia wrapper; created on first use when left unset.
    doc_store: Any = None

    def _run_tool(self, query: AnyStr) -> AnyStr:
        """Look up `query` through the `Wiki` wrapper and return its result."""
        # Build the wrapper lazily so the `wikipedia` package is only
        # required once the tool is actually executed.
        if not self.doc_store:
            self.doc_store = Wiki()
        return self.doc_store.search(query)

View File

@ -54,6 +54,8 @@ setuptools.setup(
# optional dependency needed for test
"openai",
"chromadb",
"wikipedia",
"googlesearch-python",
],
},
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},

70
tests/test_tools.py Normal file
View File

@ -0,0 +1,70 @@
import json
from pathlib import Path
import pytest
from openai.api_resources.embedding import Embedding
from kotaemon.docstores import InMemoryDocumentStore
from kotaemon.documents.base import Document
from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.pipelines.tools import ComponentTool, GoogleSearchTool, WikipediaTool
from kotaemon.vectorstores import ChromaVectorStore
# Canned embedding response used by `mock_openai_embedding`.
# The encoding is explicit so decoding does not depend on the platform's
# default locale encoding.
with open(
    Path(__file__).parent / "resources" / "embedding_openai.json",
    encoding="utf-8",
) as f:
    openai_embedding = json.load(f)
@pytest.fixture(scope="function")
def mock_openai_embedding(monkeypatch):
    """Patch `Embedding.create` to return the canned embedding payload."""

    def _fake_create(*args, **kwargs):
        # Ignore all arguments and always hand back the recorded response.
        return openai_embedding

    monkeypatch.setattr(Embedding, "create", _fake_create)
def test_google_tool():
    """GoogleSearchTool exposes metadata and yields a non-empty result."""
    # NOTE(review): presumably this performs a live search - confirm that
    # network access is available where the test runs.
    google_tool = GoogleSearchTool()
    assert google_tool.name
    assert google_tool.description

    search_output = google_tool("What is Cinnamon AI")
    assert search_output
def test_wikipedia_tool():
    """WikipediaTool exposes metadata and yields a non-empty result."""
    # NOTE(review): presumably this queries Wikipedia for real - confirm that
    # network access is available where the test runs.
    wiki_tool = WikipediaTool()
    assert wiki_tool.name
    assert wiki_tool.description

    wiki_output = wiki_tool("Cinnamon")
    assert wiki_output
def test_pipeline_tool(mock_openai_embedding, tmp_path):
    """Wrap indexing and retrieval pipelines as tools and run them in turn."""
    vector_store = ChromaVectorStore(path=str(tmp_path))
    document_store = InMemoryDocumentStore()
    embedder = AzureOpenAIEmbeddings(
        model="text-embedding-ada-002",
        deployment="embedding-deployment",
        openai_api_base="https://test.openai.azure.com/",
        openai_api_key="some-key",
    )

    # Both pipelines share the same stores, so what is indexed is searchable.
    indexer = IndexVectorStoreFromDocumentPipeline(
        vector_store=vector_store, embedding=embedder, doc_store=document_store
    )
    retriever = RetrieveDocumentFromVectorStorePipeline(
        vector_store=vector_store, doc_store=document_store, embedding=embedder
    )

    # Index one document through the tool interface (dict input).
    index_tool = ComponentTool(
        name="index_document",
        description="A tool to use to index a document to be searched later",
        component=indexer,
    )
    index_tool({"text": Document(text="Cinnamon AI")})

    # Query it back through a second tool (string input).
    retrieval_tool = ComponentTool(
        name="search_document",
        description="A tool to use to search a document in a vectorstore",
        component=retriever,
    )
    retrieved = retrieval_tool("Cinnamon AI")
    assert retrieved