[AUR-395] Adopt Example1 disclaimer pipeline (#42)

* Adopt Example1 disclaimer pipeline
* Update Document class
* Add composite components
* Modify Extractor behaviours
This commit is contained in:
ian_Cin
2023-10-10 15:42:48 +07:00
committed by GitHub
parent 79cc60e6a2
commit 84f1fa8cbd
12 changed files with 654 additions and 37 deletions

View File

@@ -0,0 +1,9 @@
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .linear import GatedLinearPipeline, SimpleLinearPipeline
__all__ = [
"SimpleLinearPipeline",
"GatedLinearPipeline",
"SimpleBranchingPipeline",
"GatedBranchingPipeline",
]

View File

@@ -0,0 +1,182 @@
from typing import List, Optional
from theflow import Param
from kotaemon.base import BaseComponent
from kotaemon.composite.linear import GatedLinearPipeline
from kotaemon.documents.base import Document
class SimpleBranchingPipeline(BaseComponent):
"""
A simple branching pipeline for executing multiple branches.
Attributes:
branches (List[BaseComponent]): The list of branches to be executed.
Example Usage:
from kotaemon.composite import GatedLinearPipeline
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
pipeline = SimpleBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
print(pipeline(condition_text="1"))
print(pipeline(condition_text="2"))
print(pipeline(condition_text="12"))
"""
branches: List[BaseComponent] = Param(default_callback=lambda *_: [])
def add_branch(self, component: BaseComponent):
"""
Add a new branch to the pipeline.
Args:
component (BaseComponent): The branch component to be added.
"""
self.branches.append(component)
def run(self, **prompt_kwargs):
"""
Execute the pipeline by running each branch and return the outputs as a list.
Args:
**prompt_kwargs: Keyword arguments for the branches.
Returns:
List: The outputs of each branch as a list.
"""
output = []
for i, branch in enumerate(self.branches):
self._prepare_child(branch, name=f"branch-{i}")
output.append(branch(**prompt_kwargs))
return output
class GatedBranchingPipeline(SimpleBranchingPipeline):
"""
A simple gated branching pipeline for executing multiple branches based on a
condition.
This class extends the SimpleBranchingPipeline class and adds the ability to execute
the branches until a branch returns a non-empty output based on a condition.
Attributes:
branches (List[BaseComponent]): The list of branches to be executed.
Example Usage:
from kotaemon.composite import GatedLinearPipeline
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
print(pipeline(condition_text="1"))
print(pipeline(condition_text="2"))
"""
def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
"""
Execute the pipeline by running each branch and return the output of the first
branch that returns a non-empty output based on the provided condition.
Args:
condition_text (str): The condition text to evaluate for each branch.
Default to None.
**prompt_kwargs: Keyword arguments for the branches.
Returns:
Union[OutputType, None]: The output of the first branch that satisfies the
condition, or None if no branch satisfies the condition.
Raise:
ValueError: If condition_text is None
"""
if condition_text is None:
raise ValueError("`condition_text` must be provided.")
for i, branch in enumerate(self.branches):
self._prepare_child(branch, name=f"branch-{i}")
output = branch(condition_text=condition_text, **prompt_kwargs)
if output:
return output
return Document(None)
if __name__ == "__main__":
import dotenv
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
secrets = dotenv.dotenv_values(".env")
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
for i in range(3):
pipeline.add_branch(
GatedLinearPipeline(
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
condition=RegexExtractor(pattern=f"{i}"),
llm=llm,
post_processor=identity,
)
)
pipeline(condition_text="1")

View File

@@ -0,0 +1,153 @@
from typing import Any, Callable, Optional, Union
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document, IO_Type
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.prompt.base import BasePromptComponent
class SimpleLinearPipeline(BaseComponent):
"""
A simple pipeline for running a function with a prompt, a language model, and an
optional post-processor.
Attributes:
prompt (BasePromptComponent): The prompt component used to generate the initial
input.
llm (Union[ChatLLM, LLM]): The language model component used to generate the
output.
post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An optional
post-processor component or function.
Example Usage:
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
pipeline = SimpleLinearPipeline(
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
llm=llm,
post_processor=identity,
)
print(pipeline(word="lone"))
"""
prompt: BasePromptComponent
llm: Union[ChatLLM, LLM]
post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]
def run(
self,
*,
llm_kwargs: Optional[dict] = {},
post_processor_kwargs: Optional[dict] = {},
**prompt_kwargs,
):
"""
Run the function with the given arguments and return the final output as a
Document object.
Args:
llm_kwargs (dict): Keyword arguments for the llm call.
post_processor_kwargs (dict): Keyword arguments for the post_processor.
**prompt_kwargs: Keyword arguments for populating the prompt.
Returns:
Document: The final output of the function as a Document object.
"""
prompt = self.prompt(**prompt_kwargs)
llm_output = self.llm(prompt.text, **llm_kwargs)
if self.post_processor is not None:
final_output = self.post_processor(llm_output, **post_processor_kwargs)
else:
final_output = llm_output
return Document(final_output)
class GatedLinearPipeline(SimpleLinearPipeline):
"""
A pipeline that extends the SimpleLinearPipeline class and adds a condition
attribute.
Attributes:
condition (Callable[[IO_Type], Any]): A callable function that represents the
condition.
Example Usage:
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
def identity(x):
return x
llm = AzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
pipeline = GatedLinearPipeline(
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
condition=RegexExtractor(pattern="some pattern"),
llm=llm,
post_processor=identity,
)
print(pipeline(condition_text="some pattern", word="lone"))
print(pipeline(condition_text="other pattern", word="lone"))
"""
condition: Callable[[IO_Type], Any]
def run(
self,
*,
condition_text: Optional[str] = None,
llm_kwargs: Optional[dict] = {},
post_processor_kwargs: Optional[dict] = {},
**prompt_kwargs,
) -> Document:
"""
Run the pipeline with the given arguments and return the final output as a
Document object.
Args:
condition_text (str): The condition text to evaluate. Default to None.
llm_kwargs (dict): Additional keyword arguments for the language model call.
post_processor_kwargs (dict): Additional keyword arguments for the
post-processor.
**prompt_kwargs: Keyword arguments for populating the prompt.
Returns:
Document: The final output of the pipeline as a Document object.
Raises:
ValueError: If condition_text is None
"""
if condition_text is None:
raise ValueError("`condition_text` must be provided")
if self.condition(condition_text):
return super().run(
llm_kwargs=llm_kwargs,
post_processor_kwargs=post_processor_kwargs,
**prompt_kwargs,
)
return Document(None)

View File

@@ -1,12 +1,43 @@
from typing import Any, Optional
from haystack.schema import Document as HaystackDocument
from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument
from pyparsing import TypeVar
IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
class Document(BaseDocument):
"""Base document class, mostly inherited from Document class from llama-index"""
"""
Base document class, mostly inherited from Document class from llama-index.
This class accept one positional argument `content` of an arbitrary type, which will
store the raw content of the document. If specified, the class will use
`content` to initialize the base llama_index class.
"""
content: Any
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
if kwargs.get("text", None) is not None:
kwargs["content"] = kwargs["text"]
elif kwargs.get("embedding", None) is not None:
kwargs["content"] = kwargs["embedding"]
elif isinstance(content, Document):
kwargs = content.dict()
else:
kwargs["content"] = content
if content:
kwargs["text"] = str(content)
else:
kwargs["text"] = ""
super().__init__(*args, **kwargs)
def __bool__(self):
return bool(self.content)
@classmethod
def example(cls) -> "Document":
@@ -23,7 +54,7 @@ class Document(BaseDocument):
return HaystackDocument(content=text, meta=metadata)
def __str__(self):
return self.text
return str(self.content)
class RetrievedDocument(Document):

View File

@@ -1,22 +1,42 @@
import re
from typing import Dict, List
from typing import Callable, Dict, List, Union
from theflow import Param
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document
class ExtractorOutput(Document):
"""
Represents the output of an extractor.
"""
matches: List[str]
class RegexExtractor(BaseComponent):
"""
Simple class for extracting text from a document using a regex pattern.
Args:
pattern (str): The regex pattern to use.
pattern (List[str]): The regex pattern(s) to use.
output_map (dict, optional): A mapping from extracted text to the
desired output. Defaults to None.
"""
pattern: str
output_map: Dict[str, str] = {}
class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
pattern: List[str]
output_map: Union[Dict[str, str], Callable[[str], str]] = Param(
default_callback=lambda *_: {}
)
def __init__(self, pattern: Union[str, List[str]], **kwargs):
if isinstance(pattern, str):
pattern = [pattern]
super().__init__(pattern=pattern, **kwargs)
@staticmethod
def run_raw_static(pattern: str, text: str) -> List[str]:
@@ -50,28 +70,34 @@ class RegexExtractor(BaseComponent):
if not output_map:
return text
return str(output_map.get(text, text))
if isinstance(output_map, dict):
return output_map.get(text, text)
def run_raw(self, text: str) -> List[Document]:
return output_map(text)
def run_raw(self, text: str) -> ExtractorOutput:
"""
Runs the raw text through the static pattern and output mapping, returning a
list of strings.
Matches the raw text against the pattern and rans the output mapping, returning
an instance of ExtractorOutput.
Args:
text (str): The raw text to be processed.
Returns:
List[str]: The processed output as a list of strings.
ExtractorOutput: The processed output as a list of ExtractorOutput.
"""
output = self.run_raw_static(self.pattern, text)
output = sum(
[self.run_raw_static(p, text) for p in self.pattern], []
) # type: List[str]
output = [self.map_output(text, self.output_map) for text in output]
return [
Document(text=text, metadata={"origin": "RegexExtractor"})
for text in output
]
return ExtractorOutput(
text=output[0] if output else "",
matches=output,
metadata={"origin": "RegexExtractor"},
)
def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]:
"""
Runs a batch of raw text inputs through the `run_raw()` method and returns the
output for each input.
@@ -80,29 +106,28 @@ class RegexExtractor(BaseComponent):
text_batch (List[str]): A list of raw text inputs to process.
Returns:
List[List[str]]: A list of lists containing the output for each input in the
List[ExtractorOutput]: A list containing the output for each input in the
batch.
"""
batch_output = [self.run_raw(each_text) for each_text in text_batch]
return batch_output
def run_document(self, document: Document) -> List[Document]:
def run_document(self, document: Document) -> ExtractorOutput:
"""
Run the document through the regex extractor and return a list of extracted
documents.
Run the document through the regex extractor and return an extracted document.
Args:
document (Document): The input document.
Returns:
List[Document]: A list of extracted documents.
ExtractorOutput: The extracted content.
"""
return self.run_raw(document.text)
def run_batch_document(
self, document_batch: List[Document]
) -> List[List[Document]]:
) -> List[ExtractorOutput]:
"""
Runs a batch of documents through the `run_document` function and returns the
output for each document.
@@ -113,15 +138,15 @@ class RegexExtractor(BaseComponent):
batch of documents to process.
Returns:
List[List[Document]]: A list of lists where each inner list contains the
output Document for each input Document in the batch.
List[ExtractorOutput]: A list contains the output ExtractorOutput for each
input Document in the batch.
Example:
document1 = Document(...)
document2 = Document(...)
document_batch = [document1, document2]
batch_output = self.run_batch_document(document_batch)
# batch_output will be [[output1_document1, ...], [output1_document2, ...]]
# batch_output will be [output1_document1, output1_document2]
"""
batch_output = [
@@ -162,3 +187,22 @@ class RegexExtractor(BaseComponent):
return True
return False
class FirstMatchRegexExtractor(RegexExtractor):
pattern: List[str]
def run_raw(self, text: str) -> ExtractorOutput:
for p in self.pattern:
output = self.run_raw_static(p, text)
if output:
output = [self.map_output(text, self.output_map) for text in output]
return ExtractorOutput(
text=output[0],
matches=output,
metadata={"origin": "FirstMatchRegexExtractor"},
)
return ExtractorOutput(
text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
)

View File

@@ -15,6 +15,9 @@ class BasePromptComponent(BaseComponent):
given template.
"""
class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__()
self.template = (