[AUR-395] Adopt Example1 disclaimer pipeline (#42)

* Adopt Example1 disclaimer pipeline
* Update Document class
* Add composite components
* Modify Extractor behaviours
This commit is contained in:
ian_Cin
2023-10-10 15:42:48 +07:00
committed by GitHub
parent 79cc60e6a2
commit 84f1fa8cbd
12 changed files with 654 additions and 37 deletions

View File

@@ -0,0 +1,9 @@
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .linear import GatedLinearPipeline, SimpleLinearPipeline
__all__ = [
"SimpleLinearPipeline",
"GatedLinearPipeline",
"SimpleBranchingPipeline",
"GatedBranchingPipeline",
]

View File

@@ -0,0 +1,182 @@
from typing import List, Optional
from theflow import Param
from kotaemon.base import BaseComponent
from kotaemon.composite.linear import GatedLinearPipeline
from kotaemon.documents.base import Document
class SimpleBranchingPipeline(BaseComponent):
"""
A simple branching pipeline for executing multiple branches.
Attributes:
branches (List[BaseComponent]): The list of branches to be executed.
Example Usage:
from kotaemon.composite import GatedLinearPipeline
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
pipeline = SimpleBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
print(pipeline(condition_text="1"))
print(pipeline(condition_text="2"))
print(pipeline(condition_text="12"))
"""
branches: List[BaseComponent] = Param(default_callback=lambda *_: [])
def add_branch(self, component: BaseComponent):
"""
Add a new branch to the pipeline.
Args:
component (BaseComponent): The branch component to be added.
"""
self.branches.append(component)
def run(self, **prompt_kwargs):
"""
Execute the pipeline by running each branch and return the outputs as a list.
Args:
**prompt_kwargs: Keyword arguments for the branches.
Returns:
List: The outputs of each branch as a list.
"""
output = []
for i, branch in enumerate(self.branches):
self._prepare_child(branch, name=f"branch-{i}")
output.append(branch(**prompt_kwargs))
return output
class GatedBranchingPipeline(SimpleBranchingPipeline):
"""
A simple gated branching pipeline for executing multiple branches based on a
condition.
This class extends the SimpleBranchingPipeline class and adds the ability to execute
the branches until a branch returns a non-empty output based on a condition.
Attributes:
branches (List[BaseComponent]): The list of branches to be executed.
Example Usage:
from kotaemon.composite import GatedLinearPipeline
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
print(pipeline(condition_text="1"))
print(pipeline(condition_text="2"))
"""
def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
"""
Execute the pipeline by running each branch and return the output of the first
branch that returns a non-empty output based on the provided condition.
Args:
condition_text (str): The condition text to evaluate for each branch.
Default to None.
**prompt_kwargs: Keyword arguments for the branches.
Returns:
Union[OutputType, None]: The output of the first branch that satisfies the
condition, or None if no branch satisfies the condition.
Raise:
ValueError: If condition_text is None
"""
if condition_text is None:
raise ValueError("`condition_text` must be provided.")
for i, branch in enumerate(self.branches):
self._prepare_child(branch, name=f"branch-{i}")
output = branch(condition_text=condition_text, **prompt_kwargs)
if output:
return output
return Document(None)
if __name__ == "__main__":
import dotenv
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
secrets = dotenv.dotenv_values(".env")
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
pipeline(condition_text="1")

View File

@@ -0,0 +1,153 @@
from typing import Any, Callable, Optional, Union
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document, IO_Type
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.prompt.base import BasePromptComponent
class SimpleLinearPipeline(BaseComponent):
"""
A simple pipeline for running a function with a prompt, a language model, and an
optional post-processor.
Attributes:
prompt (BasePromptComponent): The prompt component used to generate the initial
input.
llm (Union[ChatLLM, LLM]): The language model component used to generate the
output.
post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An optional
post-processor component or function.
Example Usage:
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
pipeline = SimpleLinearPipeline(
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
llm=llm,
post_processor=identity,
)
print(pipeline(word="lone"))
"""
prompt: BasePromptComponent
llm: Union[ChatLLM, LLM]
post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]
def run(
self,
*,
llm_kwargs: Optional[dict] = {},
post_processor_kwargs: Optional[dict] = {},
**prompt_kwargs,
):
"""
Run the function with the given arguments and return the final output as a
Document object.
Args:
llm_kwargs (dict): Keyword arguments for the llm call.
post_processor_kwargs (dict): Keyword arguments for the post_processor.
**prompt_kwargs: Keyword arguments for populating the prompt.
Returns:
Document: The final output of the function as a Document object.
"""
prompt = self.prompt(**prompt_kwargs)
llm_output = self.llm(prompt.text, **llm_kwargs)
if self.post_processor is not None:
final_output = self.post_processor(llm_output, **post_processor_kwargs)
else:
final_output = llm_output
return Document(final_output)
class GatedLinearPipeline(SimpleLinearPipeline):
"""
A pipeline that extends the SimpleLinearPipeline class and adds a condition
attribute.
Attributes:
condition (Callable[[IO_Type], Any]): A callable function that represents the
condition.
Example Usage:
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
pipeline = GatedLinearPipeline(
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
condition=RegexExtractor(pattern="some pattern"),
llm=llm,
post_processor=identity,
)
print(pipeline(condition_text="some pattern", word="lone"))
print(pipeline(condition_text="other pattern", word="lone"))
"""
condition: Callable[[IO_Type], Any]
def run(
self,
*,
condition_text: Optional[str] = None,
llm_kwargs: Optional[dict] = {},
post_processor_kwargs: Optional[dict] = {},
**prompt_kwargs,
) -> Document:
"""
Run the pipeline with the given arguments and return the final output as a
Document object.
Args:
condition_text (str): The condition text to evaluate. Default to None.
llm_kwargs (dict): Additional keyword arguments for the language model call.
post_processor_kwargs (dict): Additional keyword arguments for the
post-processor.
**prompt_kwargs: Keyword arguments for populating the prompt.
Returns:
Document: The final output of the pipeline as a Document object.
Raises:
ValueError: If condition_text is None
"""
if condition_text is None:
raise ValueError("`condition_text` must be provided")
if self.condition(condition_text):
return super().run(
llm_kwargs=llm_kwargs,
post_processor_kwargs=post_processor_kwargs,
**prompt_kwargs,
)
return Document(None)