diff --git a/knowledgehub/pipelines/tools/__init__.py b/knowledgehub/pipelines/tools/__init__.py
new file mode 100644
index 0000000..7302ed6
--- /dev/null
+++ b/knowledgehub/pipelines/tools/__init__.py
@@ -0,0 +1,5 @@
+from .base import BaseTool, ComponentTool
+from .google import GoogleSearchTool
+from .wikipedia import WikipediaTool
+
+__all__ = ["BaseTool", "ComponentTool", "GoogleSearchTool", "WikipediaTool"]
diff --git a/knowledgehub/pipelines/tools/base.py b/knowledgehub/pipelines/tools/base.py
new file mode 100644
index 0000000..4e95ddb
--- /dev/null
+++ b/knowledgehub/pipelines/tools/base.py
@@ -0,0 +1,142 @@
+from abc import abstractmethod
+from typing import Any, Callable, Dict, Optional, Tuple, Type, Union
+
+from pydantic import BaseModel
+
+from kotaemon.base import BaseComponent
+
+
+class ToolException(Exception):
+    """An optional exception that tool throws when execution error occurs.
+
+    When this exception is thrown, the agent will not stop working,
+    but will handle the exception according to the handle_tool_error
+    variable of the tool, and the processing result will be returned
+    to the agent as observation, and printed in red on the console.
+    """
+
+
+class BaseTool(BaseComponent):
+    """Base class of the tools: validates the input, then runs `_run_tool`."""
+
+    name: str
+    """The unique name of the tool that clearly communicates its purpose."""
+    description: str
+    """Description used to tell the model how/when/why to use the tool.
+    You can provide few-shot examples as a part of the description. This will be
+    input to the prompt of LLM.
+    """
+    args_schema: Optional[Type[BaseModel]] = None
+    """Pydantic model class to validate and parse the tool's input arguments."""
+    verbose: bool = False
+    """Whether to log the tool's progress."""
+    handle_tool_error: Optional[
+        Union[bool, str, Callable[[ToolException], str]]
+    ] = False
+    """Handle the content of the ToolException thrown."""
+
+    def _parse_input(
+        self,
+        tool_input: Union[str, Dict],
+    ) -> Union[str, Dict[str, Any]]:
+        """Validate the tool input against `args_schema`, when one is set.
+
+        A string input is validated as the first field of the schema and
+        returned unchanged. A dict input is parsed by the schema and only
+        the keys that the caller supplied are returned.
+        """
+        args_schema = self.args_schema
+        if args_schema is None:
+            return tool_input
+        if isinstance(tool_input, str):
+            key_ = next(iter(args_schema.__fields__))
+            args_schema.validate({key_: tool_input})
+            return tool_input
+        result = args_schema.parse_obj(tool_input)
+        return {k: v for k, v in result.dict().items() if k in tool_input}
+
+    @abstractmethod
+    def _run_tool(
+        self,
+        *args: Any,
+        **kwargs: Any,
+    ) -> Any:
+        """Call tool."""
+
+    def _to_args_and_kwargs(self, tool_input: Union[str, Dict]) -> Tuple[Tuple, Dict]:
+        """Split the parsed input into positional and keyword arguments."""
+        # For backwards compatibility, if run_input is a string,
+        # pass as a positional argument.
+        if isinstance(tool_input, str):
+            return (tool_input,), {}
+        return (), tool_input
+
+    def _handle_tool_error(self, e: ToolException) -> Any:
+        """Turn a ToolException into an observation, or re-raise it.
+
+        `handle_tool_error` decides the outcome: a falsy value re-raises,
+        `True` returns the first exception argument (or a generic message),
+        a string is returned as is, and a callable is called with the
+        exception. Any other value raises ValueError.
+        """
+        handler = self.handle_tool_error
+        if not handler:
+            raise e
+        if isinstance(handler, bool):
+            return e.args[0] if e.args else "Tool execution error"
+        if isinstance(handler, str):
+            return handler
+        if callable(handler):
+            return handler(e)
+        raise ValueError(
+            f"Got unexpected type of `handle_tool_error`. Expected bool, str "
+            f"or callable. Received: {handler}"
+        )
+
+    def run_raw(
+        self,
+        tool_input: Union[str, Dict],
+        verbose: Optional[bool] = None,
+        **kwargs: Any,
+    ) -> Any:
+        """Run the tool on `tool_input` and return its observation.
+
+        `verbose` and `kwargs` are accepted but not used yet. A ToolException
+        raised while running is resolved by `_handle_tool_error`.
+        """
+        parsed_input = self._parse_input(tool_input)
+        # TODO (verbose_): Add logging
+        try:
+            tool_args, tool_kwargs = self._to_args_and_kwargs(parsed_input)
+            return self._run_tool(*tool_args, **tool_kwargs)
+        except ToolException as e:
+            return self._handle_tool_error(e)
+
+    def run_document(self, *args, **kwargs):
+        pass
+
+    def run_batch_raw(self, *args, **kwargs):
+        pass
+
+    def run_batch_document(self, *args, **kwargs):
+        pass
+
+    def is_document(self, *args, **kwargs) -> bool:
+        """Tool does not support processing document"""
+        return False
+
+    def is_batch(self, *args, **kwargs) -> bool:
+        """Tool does not support processing batch"""
+        return False
+
+
+class ComponentTool(BaseTool):
+    """
+    A Tool based on another pipeline / BaseComponent to be used
+    as its main entry point
+    """
+
+    component: BaseComponent
+
+    def _run_tool(self, *args: Any, **kwargs: Any) -> Any:
+        return self.component(*args, **kwargs)
diff --git a/knowledgehub/pipelines/tools/google.py b/knowledgehub/pipelines/tools/google.py
new file mode 100644
index 0000000..feb28c6
--- /dev/null
+++ b/knowledgehub/pipelines/tools/google.py
@@ -0,0 +1,36 @@
+from typing import AnyStr, Optional, Type
+
+from pydantic import BaseModel, Field
+
+from .base import BaseTool
+
+
+class GoogleSearchArgs(BaseModel):
+    query: str = Field(..., description="a search query")
+
+
+class GoogleSearchTool(BaseTool):
+    """Tool that returns Google search results as text snippets."""
+
+    name = "google_search"
+    description = (
+        "A search engine retrieving top search results as snippets from Google. "
+        "Input should be a search query."
+    )
+    args_schema: Optional[Type[BaseModel]] = GoogleSearchArgs
+
+    def _run_tool(self, query: AnyStr) -> str:
+        """Return one "title description" line per search result."""
+        try:
+            from googlesearch import search
+        except ImportError as exc:
+            raise ImportError(
+                "install googlesearch using `pip3 install googlesearch-python` to "
+                "use this tool"
+            ) from exc
+        search_results = search(query, advanced=True)
+        if not search_results:
+            return ""
+        return "\n".join(
+            f"{item.title} {item.description}" for item in search_results
+        )
diff --git a/knowledgehub/pipelines/tools/wikipedia.py b/knowledgehub/pipelines/tools/wikipedia.py
new file mode 100644
index 0000000..ef6b8d2
--- /dev/null
+++ b/knowledgehub/pipelines/tools/wikipedia.py
@@ -0,0 +1,62 @@
+from typing import Any, AnyStr, Optional, Type, Union
+
+from pydantic import BaseModel, Field
+
+from kotaemon.documents.base import Document
+
+from .base import BaseTool
+
+
+class Wiki:
+    """Wrapper around wikipedia API."""
+
+    def __init__(self) -> None:
+        """Check that wikipedia package is installed."""
+        try:
+            import wikipedia  # noqa: F401
+        except ImportError as exc:
+            raise ValueError(
+                "Could not import wikipedia python package. "
+                "Please install it with `pip install wikipedia`."
+            ) from exc
+
+    def search(self, search: str) -> Union[str, Document]:
+        """Try to search for wiki page.
+
+        If page exists, return a Document holding the page content, with the
+        page URL stored under the "page" metadata key.
+        If the lookup raises PageError or DisambiguationError, return a
+        message listing similar entries.
+        """
+        import wikipedia
+
+        try:
+            # Look the page up once and reuse it for both the content and the URL.
+            page = wikipedia.page(search)
+            return Document(text=page.content, metadata={"page": page.url})
+        except (wikipedia.PageError, wikipedia.DisambiguationError):
+            similar = wikipedia.search(search)
+            return f"Could not find [{search}]. Similar: {similar}"
+
+
+class WikipediaArgs(BaseModel):
+    query: str = Field(..., description="a search query as input to wikipedia")
+
+
+class WikipediaTool(BaseTool):
+    """Tool that adds the capability to query the Wikipedia API."""
+
+    name = "wikipedia"
+    description = (
+        "Search engine from Wikipedia, retrieving relevant wiki page. "
+        "Useful when you need to get holistic knowledge about people, "
+        "places, companies, historical events, or other subjects."
+    )
+    args_schema: Optional[Type[BaseModel]] = WikipediaArgs
+    doc_store: Any = None
+
+    def _run_tool(self, query: AnyStr) -> Union[str, Document]:
+        """Search Wikipedia for `query`, creating the wrapper on first use."""
+        if not self.doc_store:
+            self.doc_store = Wiki()
+        return self.doc_store.search(query)
diff --git a/setup.py b/setup.py
index b0aa963..152d537 100644
--- a/setup.py
+++ b/setup.py
@@ -54,6 +54,8 @@ setuptools.setup(
             # optional dependency needed for test
             "openai",
             "chromadb",
+            "wikipedia",
+            "googlesearch-python",
         ],
     },
     entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
diff --git a/tests/test_tools.py b/tests/test_tools.py
new file mode 100644
index 0000000..c010862
--- /dev/null
+++ b/tests/test_tools.py
@@ -0,0 +1,70 @@
+import json
+from pathlib import Path
+
+import pytest
+from openai.api_resources.embedding import Embedding
+
+from kotaemon.docstores import InMemoryDocumentStore
+from kotaemon.documents.base import Document
+from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
+from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
+from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
+from kotaemon.pipelines.tools import ComponentTool, GoogleSearchTool, WikipediaTool
+from kotaemon.vectorstores import ChromaVectorStore
+
+with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
+    openai_embedding = json.load(f)
+
+
+@pytest.fixture(scope="function")
+def mock_openai_embedding(monkeypatch):
+    monkeypatch.setattr(Embedding, "create", lambda *args, **kwargs: openai_embedding)
+
+
+def test_google_tool():
+    tool = GoogleSearchTool()
+    assert tool.name
+    assert tool.description
+    output = tool("What is Cinnamon AI")
+    assert output
+
+
+def test_wikipedia_tool():
+    tool = WikipediaTool()
+    assert tool.name
+    assert tool.description
+    output = tool("Cinnamon")
+    assert output
+
+
+def test_pipeline_tool(mock_openai_embedding, tmp_path):
+    db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
+    embedding = AzureOpenAIEmbeddings(
+        model="text-embedding-ada-002",
+        deployment="embedding-deployment",
+        openai_api_base="https://test.openai.azure.com/",
+        openai_api_key="some-key",
+    )
+
+    index_pipeline = IndexVectorStoreFromDocumentPipeline(
+        vector_store=db, embedding=embedding, doc_store=doc_store
+    )
+    retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
+        vector_store=db, doc_store=doc_store, embedding=embedding
+    )
+
+    index_tool = ComponentTool(
+        name="index_document",
+        description="A tool to use to index a document to be searched later",
+        component=index_pipeline,
+    )
+    output = index_tool({"text": Document(text="Cinnamon AI")})
+
+    retrieval_tool = ComponentTool(
+        name="search_document",
+        description="A tool to use to search a document in a vectorstore",
+        component=retrieval_pipeline,
+    )
+    output = retrieval_tool("Cinnamon AI")
+    assert output