[AUR-395] Adopt Example1 disclaimer pipeline (#42)
* Adopt Example1 disclaimer pipeline * Update Document class * Add composite components * Modify Extractor behaviours
This commit is contained in:
parent
79cc60e6a2
commit
84f1fa8cbd
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -458,3 +458,4 @@ logs/
|
||||||
|
|
||||||
S.gpg-agent*
|
S.gpg-agent*
|
||||||
.vscode/settings.json
|
.vscode/settings.json
|
||||||
|
examples/example1/assets
|
||||||
|
|
|
@ -48,5 +48,10 @@ repos:
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
|
additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
|
||||||
args: ["--check-untyped-defs", "--ignore-missing-imports"]
|
args:
|
||||||
|
[
|
||||||
|
"--check-untyped-defs",
|
||||||
|
"--ignore-missing-imports",
|
||||||
|
"--new-type-inference",
|
||||||
|
]
|
||||||
exclude: "^templates/"
|
exclude: "^templates/"
|
||||||
|
|
9
knowledgehub/composite/__init__.py
Normal file
9
knowledgehub/composite/__init__.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
|
||||||
|
from .linear import GatedLinearPipeline, SimpleLinearPipeline
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SimpleLinearPipeline",
|
||||||
|
"GatedLinearPipeline",
|
||||||
|
"SimpleBranchingPipeline",
|
||||||
|
"GatedBranchingPipeline",
|
||||||
|
]
|
182
knowledgehub/composite/branching.py
Normal file
182
knowledgehub/composite/branching.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from theflow import Param
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
from kotaemon.composite.linear import GatedLinearPipeline
|
||||||
|
from kotaemon.documents.base import Document
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleBranchingPipeline(BaseComponent):
|
||||||
|
"""
|
||||||
|
A simple branching pipeline for executing multiple branches.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
branches (List[BaseComponent]): The list of branches to be executed.
|
||||||
|
|
||||||
|
Example Usage:
|
||||||
|
from kotaemon.composite import GatedLinearPipeline
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.post_processing.extractor import RegexExtractor
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
def identity(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
pipeline = SimpleBranchingPipeline()
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="your openai api base",
|
||||||
|
openai_api_key="your openai api key",
|
||||||
|
openai_api_version="your openai api version",
|
||||||
|
deployment_name="dummy-q2-gpt35",
|
||||||
|
temperature=0,
|
||||||
|
request_timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
pipeline.add_branch(
|
||||||
|
GatedLinearPipeline(
|
||||||
|
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
|
||||||
|
condition=RegexExtractor(pattern=f"{i}"),
|
||||||
|
llm=llm,
|
||||||
|
post_processor=identity,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(pipeline(condition_text="1"))
|
||||||
|
print(pipeline(condition_text="2"))
|
||||||
|
print(pipeline(condition_text="12"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
branches: List[BaseComponent] = Param(default_callback=lambda *_: [])
|
||||||
|
|
||||||
|
def add_branch(self, component: BaseComponent):
|
||||||
|
"""
|
||||||
|
Add a new branch to the pipeline.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component (BaseComponent): The branch component to be added.
|
||||||
|
"""
|
||||||
|
self.branches.append(component)
|
||||||
|
|
||||||
|
def run(self, **prompt_kwargs):
|
||||||
|
"""
|
||||||
|
Execute the pipeline by running each branch and return the outputs as a list.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**prompt_kwargs: Keyword arguments for the branches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List: The outputs of each branch as a list.
|
||||||
|
"""
|
||||||
|
output = []
|
||||||
|
for i, branch in enumerate(self.branches):
|
||||||
|
self._prepare_child(branch, name=f"branch-{i}")
|
||||||
|
output.append(branch(**prompt_kwargs))
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class GatedBranchingPipeline(SimpleBranchingPipeline):
|
||||||
|
"""
|
||||||
|
A simple gated branching pipeline for executing multiple branches based on a
|
||||||
|
condition.
|
||||||
|
|
||||||
|
This class extends the SimpleBranchingPipeline class and adds the ability to execute
|
||||||
|
the branches until a branch returns a non-empty output based on a condition.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
branches (List[BaseComponent]): The list of branches to be executed.
|
||||||
|
|
||||||
|
Example Usage:
|
||||||
|
from kotaemon.composite import GatedLinearPipeline
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.post_processing.extractor import RegexExtractor
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
def identity(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
pipeline = GatedBranchingPipeline()
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="your openai api base",
|
||||||
|
openai_api_key="your openai api key",
|
||||||
|
openai_api_version="your openai api version",
|
||||||
|
deployment_name="dummy-q2-gpt35",
|
||||||
|
temperature=0,
|
||||||
|
request_timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
pipeline.add_branch(
|
||||||
|
GatedLinearPipeline(
|
||||||
|
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
|
||||||
|
condition=RegexExtractor(pattern=f"{i}"),
|
||||||
|
llm=llm,
|
||||||
|
post_processor=identity,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
print(pipeline(condition_text="1"))
|
||||||
|
print(pipeline(condition_text="2"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
|
||||||
|
"""
|
||||||
|
Execute the pipeline by running each branch and return the output of the first
|
||||||
|
branch that returns a non-empty output based on the provided condition.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
condition_text (str): The condition text to evaluate for each branch.
|
||||||
|
Default to None.
|
||||||
|
**prompt_kwargs: Keyword arguments for the branches.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[OutputType, None]: The output of the first branch that satisfies the
|
||||||
|
condition, or None if no branch satisfies the condition.
|
||||||
|
|
||||||
|
Raise:
|
||||||
|
ValueError: If condition_text is None
|
||||||
|
"""
|
||||||
|
if condition_text is None:
|
||||||
|
raise ValueError("`condition_text` must be provided.")
|
||||||
|
|
||||||
|
for i, branch in enumerate(self.branches):
|
||||||
|
self._prepare_child(branch, name=f"branch-{i}")
|
||||||
|
output = branch(condition_text=condition_text, **prompt_kwargs)
|
||||||
|
if output:
|
||||||
|
return output
|
||||||
|
|
||||||
|
return Document(None)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
import dotenv
|
||||||
|
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.post_processing.extractor import RegexExtractor
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
def identity(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
secrets = dotenv.dotenv_values(".env")
|
||||||
|
|
||||||
|
pipeline = GatedBranchingPipeline()
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
|
||||||
|
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
|
||||||
|
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
|
||||||
|
deployment_name="dummy-q2-gpt35",
|
||||||
|
temperature=0,
|
||||||
|
request_timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
pipeline.add_branch(
|
||||||
|
GatedLinearPipeline(
|
||||||
|
prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
|
||||||
|
condition=RegexExtractor(pattern=f"{i}"),
|
||||||
|
llm=llm,
|
||||||
|
post_processor=identity,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pipeline(condition_text="1")
|
153
knowledgehub/composite/linear.py
Normal file
153
knowledgehub/composite/linear.py
Normal file
|
@ -0,0 +1,153 @@
|
||||||
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
|
from kotaemon.base import BaseComponent
|
||||||
|
from kotaemon.documents.base import Document, IO_Type
|
||||||
|
from kotaemon.llms.chats.base import ChatLLM
|
||||||
|
from kotaemon.llms.completions.base import LLM
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
|
||||||
|
class SimpleLinearPipeline(BaseComponent):
|
||||||
|
"""
|
||||||
|
A simple pipeline for running a function with a prompt, a language model, and an
|
||||||
|
optional post-processor.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
prompt (BasePromptComponent): The prompt component used to generate the initial
|
||||||
|
input.
|
||||||
|
llm (Union[ChatLLM, LLM]): The language model component used to generate the
|
||||||
|
output.
|
||||||
|
post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An optional
|
||||||
|
post-processor component or function.
|
||||||
|
|
||||||
|
Example Usage:
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
def identity(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="your openai api base",
|
||||||
|
openai_api_key="your openai api key",
|
||||||
|
openai_api_version="your openai api version",
|
||||||
|
deployment_name="dummy-q2-gpt35",
|
||||||
|
temperature=0,
|
||||||
|
request_timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = SimpleLinearPipeline(
|
||||||
|
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
|
||||||
|
llm=llm,
|
||||||
|
post_processor=identity,
|
||||||
|
)
|
||||||
|
print(pipeline(word="lone"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
prompt: BasePromptComponent
|
||||||
|
llm: Union[ChatLLM, LLM]
|
||||||
|
post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
llm_kwargs: Optional[dict] = {},
|
||||||
|
post_processor_kwargs: Optional[dict] = {},
|
||||||
|
**prompt_kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Run the function with the given arguments and return the final output as a
|
||||||
|
Document object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
llm_kwargs (dict): Keyword arguments for the llm call.
|
||||||
|
post_processor_kwargs (dict): Keyword arguments for the post_processor.
|
||||||
|
**prompt_kwargs: Keyword arguments for populating the prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document: The final output of the function as a Document object.
|
||||||
|
"""
|
||||||
|
prompt = self.prompt(**prompt_kwargs)
|
||||||
|
llm_output = self.llm(prompt.text, **llm_kwargs)
|
||||||
|
if self.post_processor is not None:
|
||||||
|
final_output = self.post_processor(llm_output, **post_processor_kwargs)
|
||||||
|
else:
|
||||||
|
final_output = llm_output
|
||||||
|
|
||||||
|
return Document(final_output)
|
||||||
|
|
||||||
|
|
||||||
|
class GatedLinearPipeline(SimpleLinearPipeline):
|
||||||
|
"""
|
||||||
|
A pipeline that extends the SimpleLinearPipeline class and adds a condition
|
||||||
|
attribute.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
condition (Callable[[IO_Type], Any]): A callable function that represents the
|
||||||
|
condition.
|
||||||
|
|
||||||
|
Example Usage:
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.post_processing.extractor import RegexExtractor
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
def identity(x):
|
||||||
|
return x
|
||||||
|
|
||||||
|
llm = AzureChatOpenAI(
|
||||||
|
openai_api_base="your openai api base",
|
||||||
|
openai_api_key="your openai api key",
|
||||||
|
openai_api_version="your openai api version",
|
||||||
|
deployment_name="dummy-q2-gpt35",
|
||||||
|
temperature=0,
|
||||||
|
request_timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = GatedLinearPipeline(
|
||||||
|
prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
|
||||||
|
condition=RegexExtractor(pattern="some pattern"),
|
||||||
|
llm=llm,
|
||||||
|
post_processor=identity,
|
||||||
|
)
|
||||||
|
print(pipeline(condition_text="some pattern", word="lone"))
|
||||||
|
print(pipeline(condition_text="other pattern", word="lone"))
|
||||||
|
"""
|
||||||
|
|
||||||
|
condition: Callable[[IO_Type], Any]
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
condition_text: Optional[str] = None,
|
||||||
|
llm_kwargs: Optional[dict] = {},
|
||||||
|
post_processor_kwargs: Optional[dict] = {},
|
||||||
|
**prompt_kwargs,
|
||||||
|
) -> Document:
|
||||||
|
"""
|
||||||
|
Run the pipeline with the given arguments and return the final output as a
|
||||||
|
Document object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
condition_text (str): The condition text to evaluate. Default to None.
|
||||||
|
llm_kwargs (dict): Additional keyword arguments for the language model call.
|
||||||
|
post_processor_kwargs (dict): Additional keyword arguments for the
|
||||||
|
post-processor.
|
||||||
|
**prompt_kwargs: Keyword arguments for populating the prompt.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Document: The final output of the pipeline as a Document object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If condition_text is None
|
||||||
|
"""
|
||||||
|
if condition_text is None:
|
||||||
|
raise ValueError("`condition_text` must be provided")
|
||||||
|
|
||||||
|
if self.condition(condition_text):
|
||||||
|
return super().run(
|
||||||
|
llm_kwargs=llm_kwargs,
|
||||||
|
post_processor_kwargs=post_processor_kwargs,
|
||||||
|
**prompt_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
return Document(None)
|
|
@ -1,12 +1,43 @@
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
from haystack.schema import Document as HaystackDocument
|
from haystack.schema import Document as HaystackDocument
|
||||||
from llama_index.bridge.pydantic import Field
|
from llama_index.bridge.pydantic import Field
|
||||||
from llama_index.schema import Document as BaseDocument
|
from llama_index.schema import Document as BaseDocument
|
||||||
|
from pyparsing import TypeVar
|
||||||
|
|
||||||
|
IO_Type = TypeVar("IO_Type", "Document", str)
|
||||||
SAMPLE_TEXT = "A sample Document from kotaemon"
|
SAMPLE_TEXT = "A sample Document from kotaemon"
|
||||||
|
|
||||||
|
|
||||||
class Document(BaseDocument):
|
class Document(BaseDocument):
|
||||||
"""Base document class, mostly inherited from Document class from llama-index"""
|
"""
|
||||||
|
Base document class, mostly inherited from Document class from llama-index.
|
||||||
|
|
||||||
|
This class accept one positional argument `content` of an arbitrary type, which will
|
||||||
|
store the raw content of the document. If specified, the class will use
|
||||||
|
`content` to initialize the base llama_index class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content: Any
|
||||||
|
|
||||||
|
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
|
||||||
|
if content is None:
|
||||||
|
if kwargs.get("text", None) is not None:
|
||||||
|
kwargs["content"] = kwargs["text"]
|
||||||
|
elif kwargs.get("embedding", None) is not None:
|
||||||
|
kwargs["content"] = kwargs["embedding"]
|
||||||
|
elif isinstance(content, Document):
|
||||||
|
kwargs = content.dict()
|
||||||
|
else:
|
||||||
|
kwargs["content"] = content
|
||||||
|
if content:
|
||||||
|
kwargs["text"] = str(content)
|
||||||
|
else:
|
||||||
|
kwargs["text"] = ""
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self.content)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def example(cls) -> "Document":
|
def example(cls) -> "Document":
|
||||||
|
@ -23,7 +54,7 @@ class Document(BaseDocument):
|
||||||
return HaystackDocument(content=text, meta=metadata)
|
return HaystackDocument(content=text, meta=metadata)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.text
|
return str(self.content)
|
||||||
|
|
||||||
|
|
||||||
class RetrievedDocument(Document):
|
class RetrievedDocument(Document):
|
||||||
|
|
|
@ -1,22 +1,42 @@
|
||||||
import re
|
import re
|
||||||
from typing import Dict, List
|
from typing import Callable, Dict, List, Union
|
||||||
|
|
||||||
|
from theflow import Param
|
||||||
|
|
||||||
from kotaemon.base import BaseComponent
|
from kotaemon.base import BaseComponent
|
||||||
from kotaemon.documents.base import Document
|
from kotaemon.documents.base import Document
|
||||||
|
|
||||||
|
|
||||||
|
class ExtractorOutput(Document):
|
||||||
|
"""
|
||||||
|
Represents the output of an extractor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
matches: List[str]
|
||||||
|
|
||||||
|
|
||||||
class RegexExtractor(BaseComponent):
|
class RegexExtractor(BaseComponent):
|
||||||
"""
|
"""
|
||||||
Simple class for extracting text from a document using a regex pattern.
|
Simple class for extracting text from a document using a regex pattern.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
pattern (str): The regex pattern to use.
|
pattern (List[str]): The regex pattern(s) to use.
|
||||||
output_map (dict, optional): A mapping from extracted text to the
|
output_map (dict, optional): A mapping from extracted text to the
|
||||||
desired output. Defaults to None.
|
desired output. Defaults to None.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
pattern: str
|
class Config:
|
||||||
output_map: Dict[str, str] = {}
|
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
|
||||||
|
|
||||||
|
pattern: List[str]
|
||||||
|
output_map: Union[Dict[str, str], Callable[[str], str]] = Param(
|
||||||
|
default_callback=lambda *_: {}
|
||||||
|
)
|
||||||
|
|
||||||
|
def __init__(self, pattern: Union[str, List[str]], **kwargs):
|
||||||
|
if isinstance(pattern, str):
|
||||||
|
pattern = [pattern]
|
||||||
|
super().__init__(pattern=pattern, **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def run_raw_static(pattern: str, text: str) -> List[str]:
|
def run_raw_static(pattern: str, text: str) -> List[str]:
|
||||||
|
@ -50,28 +70,34 @@ class RegexExtractor(BaseComponent):
|
||||||
if not output_map:
|
if not output_map:
|
||||||
return text
|
return text
|
||||||
|
|
||||||
return str(output_map.get(text, text))
|
if isinstance(output_map, dict):
|
||||||
|
return output_map.get(text, text)
|
||||||
|
|
||||||
def run_raw(self, text: str) -> List[Document]:
|
return output_map(text)
|
||||||
|
|
||||||
|
def run_raw(self, text: str) -> ExtractorOutput:
|
||||||
"""
|
"""
|
||||||
Runs the raw text through the static pattern and output mapping, returning a
|
Matches the raw text against the pattern and rans the output mapping, returning
|
||||||
list of strings.
|
an instance of ExtractorOutput.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text (str): The raw text to be processed.
|
text (str): The raw text to be processed.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[str]: The processed output as a list of strings.
|
ExtractorOutput: The processed output as a list of ExtractorOutput.
|
||||||
"""
|
"""
|
||||||
output = self.run_raw_static(self.pattern, text)
|
output = sum(
|
||||||
|
[self.run_raw_static(p, text) for p in self.pattern], []
|
||||||
|
) # type: List[str]
|
||||||
output = [self.map_output(text, self.output_map) for text in output]
|
output = [self.map_output(text, self.output_map) for text in output]
|
||||||
|
|
||||||
return [
|
return ExtractorOutput(
|
||||||
Document(text=text, metadata={"origin": "RegexExtractor"})
|
text=output[0] if output else "",
|
||||||
for text in output
|
matches=output,
|
||||||
]
|
metadata={"origin": "RegexExtractor"},
|
||||||
|
)
|
||||||
|
|
||||||
def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]:
|
def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]:
|
||||||
"""
|
"""
|
||||||
Runs a batch of raw text inputs through the `run_raw()` method and returns the
|
Runs a batch of raw text inputs through the `run_raw()` method and returns the
|
||||||
output for each input.
|
output for each input.
|
||||||
|
@ -80,29 +106,28 @@ class RegexExtractor(BaseComponent):
|
||||||
text_batch (List[str]): A list of raw text inputs to process.
|
text_batch (List[str]): A list of raw text inputs to process.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[str]]: A list of lists containing the output for each input in the
|
List[ExtractorOutput]: A list containing the output for each input in the
|
||||||
batch.
|
batch.
|
||||||
"""
|
"""
|
||||||
batch_output = [self.run_raw(each_text) for each_text in text_batch]
|
batch_output = [self.run_raw(each_text) for each_text in text_batch]
|
||||||
|
|
||||||
return batch_output
|
return batch_output
|
||||||
|
|
||||||
def run_document(self, document: Document) -> List[Document]:
|
def run_document(self, document: Document) -> ExtractorOutput:
|
||||||
"""
|
"""
|
||||||
Run the document through the regex extractor and return a list of extracted
|
Run the document through the regex extractor and return an extracted document.
|
||||||
documents.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
document (Document): The input document.
|
document (Document): The input document.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[Document]: A list of extracted documents.
|
ExtractorOutput: The extracted content.
|
||||||
"""
|
"""
|
||||||
return self.run_raw(document.text)
|
return self.run_raw(document.text)
|
||||||
|
|
||||||
def run_batch_document(
|
def run_batch_document(
|
||||||
self, document_batch: List[Document]
|
self, document_batch: List[Document]
|
||||||
) -> List[List[Document]]:
|
) -> List[ExtractorOutput]:
|
||||||
"""
|
"""
|
||||||
Runs a batch of documents through the `run_document` function and returns the
|
Runs a batch of documents through the `run_document` function and returns the
|
||||||
output for each document.
|
output for each document.
|
||||||
|
@ -113,15 +138,15 @@ class RegexExtractor(BaseComponent):
|
||||||
batch of documents to process.
|
batch of documents to process.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
List[List[Document]]: A list of lists where each inner list contains the
|
List[ExtractorOutput]: A list contains the output ExtractorOutput for each
|
||||||
output Document for each input Document in the batch.
|
input Document in the batch.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
document1 = Document(...)
|
document1 = Document(...)
|
||||||
document2 = Document(...)
|
document2 = Document(...)
|
||||||
document_batch = [document1, document2]
|
document_batch = [document1, document2]
|
||||||
batch_output = self.run_batch_document(document_batch)
|
batch_output = self.run_batch_document(document_batch)
|
||||||
# batch_output will be [[output1_document1, ...], [output1_document2, ...]]
|
# batch_output will be [output1_document1, output1_document2]
|
||||||
"""
|
"""
|
||||||
|
|
||||||
batch_output = [
|
batch_output = [
|
||||||
|
@ -162,3 +187,22 @@ class RegexExtractor(BaseComponent):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
class FirstMatchRegexExtractor(RegexExtractor):
|
||||||
|
pattern: List[str]
|
||||||
|
|
||||||
|
def run_raw(self, text: str) -> ExtractorOutput:
|
||||||
|
for p in self.pattern:
|
||||||
|
output = self.run_raw_static(p, text)
|
||||||
|
if output:
|
||||||
|
output = [self.map_output(text, self.output_map) for text in output]
|
||||||
|
return ExtractorOutput(
|
||||||
|
text=output[0],
|
||||||
|
matches=output,
|
||||||
|
metadata={"origin": "FirstMatchRegexExtractor"},
|
||||||
|
)
|
||||||
|
|
||||||
|
return ExtractorOutput(
|
||||||
|
text=None, matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
|
||||||
|
)
|
||||||
|
|
|
@ -15,6 +15,9 @@ class BasePromptComponent(BaseComponent):
|
||||||
given template.
|
given template.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
|
||||||
|
|
||||||
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
def __init__(self, template: Union[str, PromptTemplate], **kwargs):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.template = (
|
self.template = (
|
||||||
|
|
141
tests/test_composite.py
Normal file
141
tests/test_composite.py
Normal file
|
@ -0,0 +1,141 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from kotaemon.composite import (
|
||||||
|
GatedBranchingPipeline,
|
||||||
|
GatedLinearPipeline,
|
||||||
|
SimpleBranchingPipeline,
|
||||||
|
SimpleLinearPipeline,
|
||||||
|
)
|
||||||
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
||||||
|
from kotaemon.post_processing.extractor import RegexExtractor
|
||||||
|
from kotaemon.prompt.base import BasePromptComponent
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_llm():
|
||||||
|
return AzureChatOpenAI(
|
||||||
|
openai_api_base="OPENAI_API_BASE",
|
||||||
|
openai_api_key="OPENAI_API_KEY",
|
||||||
|
openai_api_version="OPENAI_API_VERSION",
|
||||||
|
deployment_name="dummy-q2-gpt35",
|
||||||
|
temperature=0,
|
||||||
|
request_timeout=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_post_processor():
|
||||||
|
return RegexExtractor(pattern=r"\d+")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_prompt():
|
||||||
|
return BasePromptComponent(template="Test prompt {value}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
|
||||||
|
return SimpleLinearPipeline(
|
||||||
|
prompt=mock_prompt, llm=mock_llm, post_processor=mock_post_processor
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
|
||||||
|
return GatedLinearPipeline(
|
||||||
|
prompt=mock_prompt,
|
||||||
|
llm=mock_llm,
|
||||||
|
post_processor=mock_post_processor,
|
||||||
|
condition=RegexExtractor(pattern="positive"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
|
||||||
|
return GatedLinearPipeline(
|
||||||
|
prompt=mock_prompt,
|
||||||
|
llm=mock_llm,
|
||||||
|
post_processor=mock_post_processor,
|
||||||
|
condition=RegexExtractor(pattern="negative"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
|
||||||
|
openai_mocker = mocker.patch.object(
|
||||||
|
AzureChatOpenAI, "run", return_value="This is a test 123"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mock_simple_linear_pipeline.run(value="abc")
|
||||||
|
|
||||||
|
assert result.text == "123"
|
||||||
|
assert openai_mocker.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_gated_linear_pipeline_run_positive(
|
||||||
|
mocker, mock_gated_linear_pipeline_positive
|
||||||
|
):
|
||||||
|
openai_mocker = mocker.patch.object(
|
||||||
|
AzureChatOpenAI, "run", return_value="This is a test 123."
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mock_gated_linear_pipeline_positive.run(
|
||||||
|
value="abc", condition_text="positive condition"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.text == "123"
|
||||||
|
assert openai_mocker.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_gated_linear_pipeline_run_negative(
|
||||||
|
mocker, mock_gated_linear_pipeline_positive
|
||||||
|
):
|
||||||
|
openai_mocker = mocker.patch.object(
|
||||||
|
AzureChatOpenAI, "run", return_value="This is a test 123."
|
||||||
|
)
|
||||||
|
|
||||||
|
result = mock_gated_linear_pipeline_positive.run(
|
||||||
|
value="abc", condition_text="negative condition"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.content is None
|
||||||
|
assert openai_mocker.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
|
||||||
|
openai_mocker = mocker.patch.object(
|
||||||
|
AzureChatOpenAI,
|
||||||
|
"run",
|
||||||
|
side_effect=[
|
||||||
|
"This is a test 123.",
|
||||||
|
"a quick brown fox",
|
||||||
|
"jumps over the lazy dog 456",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
pipeline = SimpleBranchingPipeline()
|
||||||
|
for _ in range(3):
|
||||||
|
pipeline.add_branch(mock_simple_linear_pipeline)
|
||||||
|
|
||||||
|
result = pipeline.run(value="abc")
|
||||||
|
texts = [each.text for each in result]
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
|
assert texts == ["123", "", "456"]
|
||||||
|
assert openai_mocker.call_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_gated_branching_pipeline_run(
|
||||||
|
mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
|
||||||
|
):
|
||||||
|
openai_mocker = mocker.patch.object(
|
||||||
|
AzureChatOpenAI, "run", return_value="a quick brown fox"
|
||||||
|
)
|
||||||
|
pipeline = GatedBranchingPipeline()
|
||||||
|
|
||||||
|
pipeline.add_branch(mock_gated_linear_pipeline_negative)
|
||||||
|
pipeline.add_branch(mock_gated_linear_pipeline_positive)
|
||||||
|
pipeline.add_branch(mock_gated_linear_pipeline_positive)
|
||||||
|
|
||||||
|
result = pipeline.run(value="abc", condition_text="positive condition")
|
||||||
|
|
||||||
|
assert result.text == ""
|
||||||
|
assert openai_mocker.call_count == 2
|
49
tests/test_documents.py
Normal file
49
tests/test_documents.py
Normal file
|
@ -0,0 +1,49 @@
|
||||||
|
from haystack.schema import Document as HaystackDocument
|
||||||
|
|
||||||
|
from kotaemon.documents.base import Document, RetrievedDocument
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_constructor_with_builtin_types():
|
||||||
|
for value in ["str", 1, {}, set(), [], tuple, None]:
|
||||||
|
doc = Document(value)
|
||||||
|
assert doc.text == (str(value) if value else "")
|
||||||
|
assert doc.content == value
|
||||||
|
assert bool(doc) == bool(value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_constructor_with_document():
|
||||||
|
text = "Sample text"
|
||||||
|
doc1 = Document(text)
|
||||||
|
doc2 = Document(doc1)
|
||||||
|
assert doc2.text == doc1.text
|
||||||
|
assert doc2.content == doc1.content
|
||||||
|
|
||||||
|
|
||||||
|
def test_document_to_haystack_format():
|
||||||
|
text = "Sample text"
|
||||||
|
metadata = {"filename": "sample.txt"}
|
||||||
|
doc = Document(text, metadata=metadata)
|
||||||
|
haystack_doc = doc.to_haystack_format()
|
||||||
|
assert isinstance(haystack_doc, HaystackDocument)
|
||||||
|
assert haystack_doc.content == doc.text
|
||||||
|
assert haystack_doc.meta == metadata
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieved_document_default_values():
|
||||||
|
sample_text = "text"
|
||||||
|
retrieved_doc = RetrievedDocument(text=sample_text)
|
||||||
|
assert retrieved_doc.text == sample_text
|
||||||
|
assert retrieved_doc.score == 0.0
|
||||||
|
assert retrieved_doc.retrieval_metadata == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_retrieved_document_attributes():
|
||||||
|
sample_text = "text"
|
||||||
|
score = 0.8
|
||||||
|
metadata = {"source": "retrieval_system"}
|
||||||
|
retrieved_doc = RetrievedDocument(
|
||||||
|
text=sample_text, score=score, retrieval_metadata=metadata
|
||||||
|
)
|
||||||
|
assert retrieved_doc.text == sample_text
|
||||||
|
assert retrieved_doc.score == score
|
||||||
|
assert retrieved_doc.retrieval_metadata == metadata
|
|
@ -14,8 +14,8 @@ def regex_extractor():
|
||||||
def test_run_document(regex_extractor):
|
def test_run_document(regex_extractor):
|
||||||
document = Document(text="This is a test. 1 2 3")
|
document = Document(text="This is a test. 1 2 3")
|
||||||
extracted_document = regex_extractor(document)
|
extracted_document = regex_extractor(document)
|
||||||
extracted_texts = [each.text for each in extracted_document]
|
assert extracted_document.text == "One"
|
||||||
assert extracted_texts == ["One", "Two", "Three"]
|
assert extracted_document.matches == ["One", "Two", "Three"]
|
||||||
|
|
||||||
|
|
||||||
def test_is_document(regex_extractor):
|
def test_is_document(regex_extractor):
|
||||||
|
@ -30,11 +30,13 @@ def test_is_batch(regex_extractor):
|
||||||
|
|
||||||
def test_run_raw(regex_extractor):
|
def test_run_raw(regex_extractor):
|
||||||
output = regex_extractor("This is a test. 123")
|
output = regex_extractor("This is a test. 123")
|
||||||
output = [each.text for each in output]
|
assert output.text == "123"
|
||||||
assert output == ["123"]
|
assert output.matches == ["123"]
|
||||||
|
|
||||||
|
|
||||||
def test_run_batch_raw(regex_extractor):
|
def test_run_batch_raw(regex_extractor):
|
||||||
output = regex_extractor(["This is a test. 123", "456"])
|
output = regex_extractor(["This is a test. 123", "456"])
|
||||||
output = [[each.text for each in batch] for batch in output]
|
extracted_text = [each.text for each in output]
|
||||||
assert output == [["123"], ["456"]]
|
extracted_matches = [each.matches for each in output]
|
||||||
|
assert extracted_text == ["123", "456"]
|
||||||
|
assert extracted_matches == [["123"], ["456"]]
|
||||||
|
|
|
@ -54,10 +54,7 @@ def test_run():
|
||||||
|
|
||||||
result = prompt()
|
result = prompt()
|
||||||
|
|
||||||
assert (
|
assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One"
|
||||||
result.text
|
|
||||||
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_set_method():
|
def test_set_method():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user