diff --git a/knowledgehub/agents/base.py b/knowledgehub/agents/base.py
index acfddcb..131466c 100644
--- a/knowledgehub/agents/base.py
+++ b/knowledgehub/agents/base.py
@@ -1,8 +1,6 @@
 from typing import Optional, Union
 
-from theflow import Node, Param
-
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Node, Param
 from kotaemon.llms import BaseLLM, PromptTemplate
 
 from .io import AgentOutput, AgentType
diff --git a/knowledgehub/agents/react/agent.py b/knowledgehub/agents/react/agent.py
index abb70bb..6b7fe97 100644
--- a/knowledgehub/agents/react/agent.py
+++ b/knowledgehub/agents/react/agent.py
@@ -2,11 +2,10 @@ import logging
 import re
 from typing import Optional
 
-from theflow import Param
-
 from kotaemon.agents.base import BaseAgent, BaseLLM
 from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
 from kotaemon.agents.tools import BaseTool
+from kotaemon.base import Param
 from kotaemon.llms import PromptTemplate
 
 FINAL_ANSWER_ACTION = "Final Answer:"
diff --git a/knowledgehub/agents/rewoo/agent.py b/knowledgehub/agents/rewoo/agent.py
index 5c6fba3..81f8b8a 100644
--- a/knowledgehub/agents/rewoo/agent.py
+++ b/knowledgehub/agents/rewoo/agent.py
@@ -3,12 +3,11 @@ import re
 from concurrent.futures import ThreadPoolExecutor
 from typing import Any
 
-from theflow import Node, Param
-
 from kotaemon.agents.base import BaseAgent
 from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
 from kotaemon.agents.tools import BaseTool
 from kotaemon.agents.utils import get_plugin_response_content
+from kotaemon.base import Node, Param
 from kotaemon.indices.qa import CitationPipeline
 from kotaemon.llms import BaseLLM, PromptTemplate
 
diff --git a/knowledgehub/base/__init__.py b/knowledgehub/base/__init__.py
index 11e9b76..52e036f 100644
--- a/knowledgehub/base/__init__.py
+++ b/knowledgehub/base/__init__.py
@@ -1,4 +1,4 @@
-from .component import BaseComponent
+from .component import BaseComponent, Node, Param, lazy
 from .schema import (
     AIMessage,
     BaseMessage,
@@ -22,4 +22,7 @@ __all__ = [
     "RetrievedDocument",
     "LLMInterface",
     "ExtractorOutput",
+    "Param",
+    "Node",
+    "lazy",
 ]
diff --git a/knowledgehub/base/component.py b/knowledgehub/base/component.py
index 71da362..bf1851f 100644
--- a/knowledgehub/base/component.py
+++ b/knowledgehub/base/component.py
@@ -1,6 +1,6 @@
 from abc import abstractmethod
 
-from theflow.base import Function
+from theflow import Function, Node, Param, lazy
 
 from kotaemon.base.schema import Document
 
@@ -35,3 +35,6 @@ class BaseComponent(Function):
     def run(self, *args, **kwargs) -> Document | list[Document] | None:
         """Run the component."""
         ...
+
+
+__all__ = ["BaseComponent", "Param", "Node", "lazy"]
diff --git a/knowledgehub/indices/ingests/files.py b/knowledgehub/indices/ingests/files.py
index 83ea8b9..1d7db35 100644
--- a/knowledgehub/indices/ingests/files.py
+++ b/knowledgehub/indices/ingests/files.py
@@ -1,9 +1,8 @@
 from pathlib import Path
 
 from llama_index.readers.base import BaseReader
-from theflow import Param
 
-from kotaemon.base import BaseComponent, Document
+from kotaemon.base import BaseComponent, Document, Param
 from kotaemon.indices.extractors import BaseDocParser
 from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
 from kotaemon.loaders import (
diff --git a/knowledgehub/llms/branching.py b/knowledgehub/llms/branching.py
index 07671f2..8f93caf 100644
--- a/knowledgehub/llms/branching.py
+++ b/knowledgehub/llms/branching.py
@@ -1,8 +1,7 @@
 from typing import List, Optional
 
-from theflow import Param
+from kotaemon.base import BaseComponent, Document, Param
 
-from ..base import BaseComponent, Document
 from .linear import GatedLinearPipeline
 
 
diff --git a/knowledgehub/llms/cot.py b/knowledgehub/llms/cot.py
index 29256b3..5786f41 100644
--- a/knowledgehub/llms/cot.py
+++ b/knowledgehub/llms/cot.py
@@ -1,9 +1,7 @@
 from copy import deepcopy
 from typing import Callable, List
 
-from theflow import Function, Node, Param
-
-from kotaemon.base import BaseComponent, Document
+from kotaemon.base import BaseComponent, Document, Node, Param
 
 from .chats import AzureChatOpenAI
 from .completions import LLM
@@ -74,7 +72,7 @@ class Thought(BaseComponent):
         )
     )
     llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
-    post_process: Function = Node(
+    post_process: BaseComponent = Node(
         help=(
             "The function post-processor that post-processes LLM output prediction ."
             "It should take a string as input (this is the LLM output text) and return "
@@ -85,7 +83,7 @@ class Thought(BaseComponent):
     @Node.auto(depends_on="prompt")
     def prompt_template(self):
         """Automatically wrap around param prompt. Can ignore"""
-        return BasePromptComponent(self.prompt)
+        return BasePromptComponent(template=self.prompt)
 
     def run(self, **kwargs) -> Document:
         """Run the chain of thought"""
diff --git a/knowledgehub/llms/prompts/base.py b/knowledgehub/llms/prompts/base.py
index d0d8d8c..564279d 100644
--- a/knowledgehub/llms/prompts/base.py
+++ b/knowledgehub/llms/prompts/base.py
@@ -1,6 +1,7 @@
 from typing import Callable, Union
 
-from ...base import BaseComponent, Document
+from kotaemon.base import BaseComponent, Document
+
 from .template import PromptTemplate
 
 
@@ -16,6 +17,7 @@ class BasePromptComponent(BaseComponent):
 
     class Config:
         middleware_switches = {"theflow.middleware.CachingMiddleware": False}
+        allow_extra = True
 
     def __init__(self, template: Union[str, PromptTemplate], **kwargs):
         super().__init__()
diff --git a/knowledgehub/parsers/regex_extractor.py b/knowledgehub/parsers/regex_extractor.py
index a269fb0..90ad365 100644
--- a/knowledgehub/parsers/regex_extractor.py
+++ b/knowledgehub/parsers/regex_extractor.py
@@ -3,10 +3,7 @@ from __future__ import annotations
 import re
 from typing import Callable
 
-from theflow import Param
-
-from kotaemon.base import BaseComponent, Document
-from kotaemon.base.schema import ExtractorOutput
+from kotaemon.base import BaseComponent, Document, ExtractorOutput, Param
 
 
 class RegexExtractor(BaseComponent):
diff --git a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
index 32735de..1739ca8 100644
--- a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
+++ b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py
@@ -1,10 +1,7 @@
 import os
 from typing import List
 
-from theflow import Node, Param
-from theflow.utils.modules import ObjectInitDeclaration as _
-
-from kotaemon.base import BaseComponent, Document, LLMInterface
+from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, lazy
 from kotaemon.contribs.promptui.logs import ResultLog
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorIndexing, VectorRetrieval
@@ -48,8 +45,8 @@ class QuestionAnsweringPipeline(BaseComponent):
 
     retrieving_pipeline: VectorRetrieval = Node(
         VectorRetrieval.withx(
-            vector_store=_(ChromaVectorStore).withx(path="./tmp"),
-            doc_store=_(SimpleFileDocumentStore).withx(path="docstore.json"),
+            vector_store=lazy(ChromaVectorStore).withx(path="./tmp"),
+            doc_store=lazy(SimpleFileDocumentStore).withx(path="docstore.json"),
             embedding=AzureOpenAIEmbeddings.withx(
                 model="text-embedding-ada-002",
                 deployment="dummy-q2-text-embedding",
@@ -78,11 +75,11 @@ class QuestionAnsweringPipeline(BaseComponent):
 
 class IndexingPipeline(VectorIndexing):
     vector_store: ChromaVectorStore = Param(
-        _(ChromaVectorStore).withx(path="./tmp"),
+        lazy(ChromaVectorStore).withx(path="./tmp"),
         ignore_ui=True,
     )
     doc_store: SimpleFileDocumentStore = Param(
-        _(SimpleFileDocumentStore).withx(path="docstore.json"),
+        lazy(SimpleFileDocumentStore).withx(path="docstore.json"),
         ignore_ui=True,
     )
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
diff --git a/tests/simple_pipeline.py b/tests/simple_pipeline.py
index 8394f7b..e3c6cf2 100644
--- a/tests/simple_pipeline.py
+++ b/tests/simple_pipeline.py
@@ -1,9 +1,7 @@
 import tempfile
 from typing import List
 
-from theflow.utils.modules import ObjectInitDeclaration as _
-
-from kotaemon.base import BaseComponent, LLMInterface
+from kotaemon.base import BaseComponent, LLMInterface, lazy
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorRetrieval
 from kotaemon.llms import AzureOpenAI
@@ -21,7 +19,7 @@ class Pipeline(BaseComponent):
     )
 
     retrieving_pipeline: VectorRetrieval = VectorRetrieval.withx(
-        vector_store=_(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
+        vector_store=lazy(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
         embedding=AzureOpenAIEmbeddings.withx(
             model="text-embedding-ada-002",
             deployment="embedding-deployment",
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index 2063a65..9cc72e0 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -22,21 +22,21 @@ def test_set_attributes():
 
 def test_check_redundant_kwargs():
     template = PromptTemplate("Hello, {name}!")
-    prompt = BasePromptComponent(template, name="Alice")
+    prompt = BasePromptComponent(template=template, name="Alice")
     with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
         prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
 
 
 def test_check_unset_placeholders():
     template = PromptTemplate("Hello, {name}! I'm {age} years old.")
-    prompt = BasePromptComponent(template, name="Alice")
+    prompt = BasePromptComponent(template=template, name="Alice")
     with pytest.raises(ValueError):
         prompt._BasePromptComponent__check_unset_placeholders()
 
 
 def test_validate_value_type():
     template = PromptTemplate("Hello, {name}!")
-    prompt = BasePromptComponent(template)
+    prompt = BasePromptComponent(template=template)
     with pytest.raises(ValueError):
         prompt._BasePromptComponent__validate_value_type(name={})
 
@@ -58,6 +58,6 @@ def test_run():
 
 def test_set_method():
     template = PromptTemplate("Hello, {name}!")
-    prompt = BasePromptComponent(template)
+    prompt = BasePromptComponent(template=template)
     prompt.set(name="Alice")
     assert prompt.name == "Alice"
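
Taken together, these changes re-export theflow's `Param`, `Node`, and `lazy` through `kotaemon.base`, so downstream pipelines no longer import from `theflow` directly, and `lazy(...)` replaces `ObjectInitDeclaration as _` for deferred object construction. A minimal sketch of a pipeline written against the new surface, mirroring `tests/simple_pipeline.py` above — the `DemoPipeline` name and the `kotaemon.storages` import path are assumptions, not part of this diff:

```python
import tempfile

from kotaemon.base import BaseComponent, Document, lazy
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices import VectorRetrieval
from kotaemon.storages import ChromaVectorStore  # assumed module path


class DemoPipeline(BaseComponent):
    # lazy(...) declares how to build the vector store without instantiating it
    # at class-definition time, exactly as ObjectInitDeclaration did before.
    retrieving_pipeline: VectorRetrieval = VectorRetrieval.withx(
        vector_store=lazy(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
        embedding=AzureOpenAIEmbeddings.withx(
            model="text-embedding-ada-002",
            deployment="embedding-deployment",
        ),
    )

    def run(self, question: str) -> list[Document]:
        # Calling a component invokes its run() via theflow's Function protocol.
        return self.retrieving_pipeline(question)
```

`BasePromptComponent` is likewise constructed with an explicit keyword now, `BasePromptComponent(template=...)`, as the updated `cot.py` and `tests/test_prompt.py` hunks show.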