Change template to private attribute and simplify imports (#101)

---------

Co-authored-by: ian <ian@cinnamon.is>
This commit is contained in:
Duc Nguyen (john) 2023-12-08 18:10:34 +07:00 committed by GitHub
parent 1f927d3391
commit da0ac1d69f
13 changed files with 31 additions and 39 deletions

View File

@ -1,8 +1,6 @@
from typing import Optional, Union from typing import Optional, Union
from theflow import Node, Param from kotaemon.base import BaseComponent, Node, Param
from kotaemon.base import BaseComponent
from kotaemon.llms import BaseLLM, PromptTemplate from kotaemon.llms import BaseLLM, PromptTemplate
from .io import AgentOutput, AgentType from .io import AgentOutput, AgentType

View File

@ -2,11 +2,10 @@ import logging
import re import re
from typing import Optional from typing import Optional
from theflow import Param
from kotaemon.agents.base import BaseAgent, BaseLLM from kotaemon.agents.base import BaseAgent, BaseLLM
from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType from kotaemon.agents.io import AgentAction, AgentFinish, AgentOutput, AgentType
from kotaemon.agents.tools import BaseTool from kotaemon.agents.tools import BaseTool
from kotaemon.base import Param
from kotaemon.llms import PromptTemplate from kotaemon.llms import PromptTemplate
FINAL_ANSWER_ACTION = "Final Answer:" FINAL_ANSWER_ACTION = "Final Answer:"

View File

@ -3,12 +3,11 @@ import re
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Any from typing import Any
from theflow import Node, Param
from kotaemon.agents.base import BaseAgent from kotaemon.agents.base import BaseAgent
from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad from kotaemon.agents.io import AgentOutput, AgentType, BaseScratchPad
from kotaemon.agents.tools import BaseTool from kotaemon.agents.tools import BaseTool
from kotaemon.agents.utils import get_plugin_response_content from kotaemon.agents.utils import get_plugin_response_content
from kotaemon.base import Node, Param
from kotaemon.indices.qa import CitationPipeline from kotaemon.indices.qa import CitationPipeline
from kotaemon.llms import BaseLLM, PromptTemplate from kotaemon.llms import BaseLLM, PromptTemplate

View File

@ -1,4 +1,4 @@
from .component import BaseComponent from .component import BaseComponent, Node, Param, lazy
from .schema import ( from .schema import (
AIMessage, AIMessage,
BaseMessage, BaseMessage,
@ -22,4 +22,7 @@ __all__ = [
"RetrievedDocument", "RetrievedDocument",
"LLMInterface", "LLMInterface",
"ExtractorOutput", "ExtractorOutput",
"Param",
"Node",
"lazy",
] ]

View File

@ -1,6 +1,6 @@
from abc import abstractmethod from abc import abstractmethod
from theflow.base import Function from theflow import Function, Node, Param, lazy
from kotaemon.base.schema import Document from kotaemon.base.schema import Document
@ -35,3 +35,6 @@ class BaseComponent(Function):
def run(self, *args, **kwargs) -> Document | list[Document] | None: def run(self, *args, **kwargs) -> Document | list[Document] | None:
"""Run the component.""" """Run the component."""
... ...
__all__ = ["BaseComponent", "Param", "Node", "lazy"]

View File

@ -1,9 +1,8 @@
from pathlib import Path from pathlib import Path
from llama_index.readers.base import BaseReader from llama_index.readers.base import BaseReader
from theflow import Param
from kotaemon.base import BaseComponent, Document from kotaemon.base import BaseComponent, Document, Param
from kotaemon.indices.extractors import BaseDocParser from kotaemon.indices.extractors import BaseDocParser
from kotaemon.indices.splitters import BaseSplitter, TokenSplitter from kotaemon.indices.splitters import BaseSplitter, TokenSplitter
from kotaemon.loaders import ( from kotaemon.loaders import (

View File

@ -1,8 +1,7 @@
from typing import List, Optional from typing import List, Optional
from theflow import Param from kotaemon.base import BaseComponent, Document, Param
from ..base import BaseComponent, Document
from .linear import GatedLinearPipeline from .linear import GatedLinearPipeline

View File

@ -1,9 +1,7 @@
from copy import deepcopy from copy import deepcopy
from typing import Callable, List from typing import Callable, List
from theflow import Function, Node, Param from kotaemon.base import BaseComponent, Document, Node, Param
from kotaemon.base import BaseComponent, Document
from .chats import AzureChatOpenAI from .chats import AzureChatOpenAI
from .completions import LLM from .completions import LLM
@ -74,7 +72,7 @@ class Thought(BaseComponent):
) )
) )
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt") llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
post_process: Function = Node( post_process: BaseComponent = Node(
help=( help=(
"The function post-processor that post-processes LLM output prediction ." "The function post-processor that post-processes LLM output prediction ."
"It should take a string as input (this is the LLM output text) and return " "It should take a string as input (this is the LLM output text) and return "
@ -85,7 +83,7 @@ class Thought(BaseComponent):
@Node.auto(depends_on="prompt") @Node.auto(depends_on="prompt")
def prompt_template(self): def prompt_template(self):
"""Automatically wrap around param prompt. Can ignore""" """Automatically wrap around param prompt. Can ignore"""
return BasePromptComponent(self.prompt) return BasePromptComponent(template=self.prompt)
def run(self, **kwargs) -> Document: def run(self, **kwargs) -> Document:
"""Run the chain of thought""" """Run the chain of thought"""

View File

@ -1,6 +1,7 @@
from typing import Callable, Union from typing import Callable, Union
from ...base import BaseComponent, Document from kotaemon.base import BaseComponent, Document
from .template import PromptTemplate from .template import PromptTemplate
@ -16,6 +17,7 @@ class BasePromptComponent(BaseComponent):
class Config: class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False} middleware_switches = {"theflow.middleware.CachingMiddleware": False}
allow_extra = True
def __init__(self, template: Union[str, PromptTemplate], **kwargs): def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__() super().__init__()

View File

@ -3,10 +3,7 @@ from __future__ import annotations
import re import re
from typing import Callable from typing import Callable
from theflow import Param from kotaemon.base import BaseComponent, Document, ExtractorOutput, Param
from kotaemon.base import BaseComponent, Document
from kotaemon.base.schema import ExtractorOutput
class RegexExtractor(BaseComponent): class RegexExtractor(BaseComponent):

View File

@ -1,10 +1,7 @@
import os import os
from typing import List from typing import List
from theflow import Node, Param from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, lazy
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent, Document, LLMInterface
from kotaemon.contribs.promptui.logs import ResultLog from kotaemon.contribs.promptui.logs import ResultLog
from kotaemon.embeddings import AzureOpenAIEmbeddings from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval from kotaemon.indices import VectorIndexing, VectorRetrieval
@ -48,8 +45,8 @@ class QuestionAnsweringPipeline(BaseComponent):
retrieving_pipeline: VectorRetrieval = Node( retrieving_pipeline: VectorRetrieval = Node(
VectorRetrieval.withx( VectorRetrieval.withx(
vector_store=_(ChromaVectorStore).withx(path="./tmp"), vector_store=lazy(ChromaVectorStore).withx(path="./tmp"),
doc_store=_(SimpleFileDocumentStore).withx(path="docstore.json"), doc_store=lazy(SimpleFileDocumentStore).withx(path="docstore.json"),
embedding=AzureOpenAIEmbeddings.withx( embedding=AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002", model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding", deployment="dummy-q2-text-embedding",
@ -78,11 +75,11 @@ class QuestionAnsweringPipeline(BaseComponent):
class IndexingPipeline(VectorIndexing): class IndexingPipeline(VectorIndexing):
vector_store: ChromaVectorStore = Param( vector_store: ChromaVectorStore = Param(
_(ChromaVectorStore).withx(path="./tmp"), lazy(ChromaVectorStore).withx(path="./tmp"),
ignore_ui=True, ignore_ui=True,
) )
doc_store: SimpleFileDocumentStore = Param( doc_store: SimpleFileDocumentStore = Param(
_(SimpleFileDocumentStore).withx(path="docstore.json"), lazy(SimpleFileDocumentStore).withx(path="docstore.json"),
ignore_ui=True, ignore_ui=True,
) )
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx( embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(

View File

@ -1,9 +1,7 @@
import tempfile import tempfile
from typing import List from typing import List
from theflow.utils.modules import ObjectInitDeclaration as _ from kotaemon.base import BaseComponent, LLMInterface, lazy
from kotaemon.base import BaseComponent, LLMInterface
from kotaemon.embeddings import AzureOpenAIEmbeddings from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices import VectorRetrieval from kotaemon.indices import VectorRetrieval
from kotaemon.llms import AzureOpenAI from kotaemon.llms import AzureOpenAI
@ -21,7 +19,7 @@ class Pipeline(BaseComponent):
) )
retrieving_pipeline: VectorRetrieval = VectorRetrieval.withx( retrieving_pipeline: VectorRetrieval = VectorRetrieval.withx(
vector_store=_(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())), vector_store=lazy(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
embedding=AzureOpenAIEmbeddings.withx( embedding=AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002", model="text-embedding-ada-002",
deployment="embedding-deployment", deployment="embedding-deployment",

View File

@ -22,21 +22,21 @@ def test_set_attributes():
def test_check_redundant_kwargs(): def test_check_redundant_kwargs():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePromptComponent(template, name="Alice") prompt = BasePromptComponent(template=template, name="Alice")
with pytest.warns(UserWarning, match="Keys provided but not in template: age"): with pytest.warns(UserWarning, match="Keys provided but not in template: age"):
prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30) prompt._BasePromptComponent__check_redundant_kwargs(name="Alice", age=30)
def test_check_unset_placeholders(): def test_check_unset_placeholders():
template = PromptTemplate("Hello, {name}! I'm {age} years old.") template = PromptTemplate("Hello, {name}! I'm {age} years old.")
prompt = BasePromptComponent(template, name="Alice") prompt = BasePromptComponent(template=template, name="Alice")
with pytest.raises(ValueError): with pytest.raises(ValueError):
prompt._BasePromptComponent__check_unset_placeholders() prompt._BasePromptComponent__check_unset_placeholders()
def test_validate_value_type(): def test_validate_value_type():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePromptComponent(template) prompt = BasePromptComponent(template=template)
with pytest.raises(ValueError): with pytest.raises(ValueError):
prompt._BasePromptComponent__validate_value_type(name={}) prompt._BasePromptComponent__validate_value_type(name={})
@ -58,6 +58,6 @@ def test_run():
def test_set_method(): def test_set_method():
template = PromptTemplate("Hello, {name}!") template = PromptTemplate("Hello, {name}!")
prompt = BasePromptComponent(template) prompt = BasePromptComponent(template=template)
prompt.set(name="Alice") prompt.set(name="Alice")
assert prompt.name == "Alice" assert prompt.name == "Alice"