[AUR-395] Adopt Example1 disclaimer pipeline (#42)

* Adopt Example1 disclaimer pipeline
* Update Document class
* Add composite components
* Modify Extractor behaviours
This commit is contained in:
ian_Cin 2023-10-10 15:42:48 +07:00 committed by GitHub
parent 79cc60e6a2
commit 84f1fa8cbd
12 changed files with 654 additions and 37 deletions

1
.gitignore vendored
View File

@ -458,3 +458,4 @@ logs/
S.gpg-agent* S.gpg-agent*
.vscode/settings.json .vscode/settings.json
examples/example1/assets

View File

@ -48,5 +48,10 @@ repos:
hooks: hooks:
- id: mypy - id: mypy
additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"] additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
args: ["--check-untyped-defs", "--ignore-missing-imports"] args:
[
"--check-untyped-defs",
"--ignore-missing-imports",
"--new-type-inference",
]
exclude: "^templates/" exclude: "^templates/"

View File

@ -0,0 +1,9 @@
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .linear import GatedLinearPipeline, SimpleLinearPipeline
__all__ = [
"SimpleLinearPipeline",
"GatedLinearPipeline",
"SimpleBranchingPipeline",
"GatedBranchingPipeline",
]

View File

@ -0,0 +1,182 @@
from typing import List, Optional
from theflow import Param
from kotaemon.base import BaseComponent
from kotaemon.composite.linear import GatedLinearPipeline
from kotaemon.documents.base import Document
class SimpleBranchingPipeline(BaseComponent):
    """
    Pipeline that forwards the same keyword arguments to every registered branch.

    Attributes:
        branches (List[BaseComponent]): Branch components, called in the order
            in which they were added.

    Example Usage:
        from kotaemon.composite import GatedLinearPipeline
        from kotaemon.llms.chats.openai import AzureChatOpenAI
        from kotaemon.post_processing.extractor import RegexExtractor
        from kotaemon.prompt.base import BasePromptComponent

        def identity(x):
            return x

        pipeline = SimpleBranchingPipeline()
        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        for i in range(3):
            pipeline.add_branch(
                GatedLinearPipeline(
                    prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
                    condition=RegexExtractor(pattern=f"{i}"),
                    llm=llm,
                    post_processor=identity,
                )
            )
        print(pipeline(condition_text="1"))
        print(pipeline(condition_text="2"))
        print(pipeline(condition_text="12"))
    """

    branches: List[BaseComponent] = Param(default_callback=lambda *_: [])

    def add_branch(self, component: BaseComponent):
        """
        Append a branch to the end of `branches`.

        Args:
            component (BaseComponent): The branch to register.
        """
        self.branches.append(component)

    def run(self, **prompt_kwargs):
        """
        Call every branch with the same keyword arguments.

        Args:
            **prompt_kwargs: Keyword arguments forwarded unchanged to each branch.

        Returns:
            List: One output per branch, in branch order.
        """

        def _invoke(index, child):
            # Each branch is prepared under the name `branch-<index>` right
            # before it is called, so preparation and calls stay interleaved.
            self._prepare_child(child, name=f"branch-{index}")
            return child(**prompt_kwargs)

        return [_invoke(index, child) for index, child in enumerate(self.branches)]
class GatedBranchingPipeline(SimpleBranchingPipeline):
    """
    Branching pipeline that stops at the first branch producing a truthy output.

    Branches are tried in insertion order. Each one receives `condition_text`
    together with the remaining keyword arguments; the first truthy result is
    returned and later branches are not called.

    Attributes:
        branches (List[BaseComponent]): The list of branches to be tried.

    Example Usage:
        from kotaemon.composite import GatedLinearPipeline
        from kotaemon.llms.chats.openai import AzureChatOpenAI
        from kotaemon.post_processing.extractor import RegexExtractor
        from kotaemon.prompt.base import BasePromptComponent

        def identity(x):
            return x

        pipeline = GatedBranchingPipeline()
        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        for i in range(3):
            pipeline.add_branch(
                GatedLinearPipeline(
                    prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
                    condition=RegexExtractor(pattern=f"{i}"),
                    llm=llm,
                    post_processor=identity,
                )
            )
        print(pipeline(condition_text="1"))
        print(pipeline(condition_text="2"))
    """

    def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
        """
        Try each branch in order and return the first truthy output.

        Args:
            condition_text (str): Text handed to every branch as its
                `condition_text` argument. Must not be None.
            **prompt_kwargs: Keyword arguments forwarded to each branch.

        Returns:
            The output of the first branch whose result is truthy, or
            `Document(None)` when no branch produces a truthy result.

        Raises:
            ValueError: If `condition_text` is None.
        """
        if condition_text is None:
            raise ValueError("`condition_text` must be provided.")

        for index, candidate in enumerate(self.branches):
            self._prepare_child(candidate, name=f"branch-{index}")
            result = candidate(condition_text=condition_text, **prompt_kwargs)
            if not result:
                # Falsy output: this branch did not fire, move on to the next.
                continue
            return result

        return Document(None)
if __name__ == "__main__":
    # Manual smoke run: builds a three-branch gated pipeline against Azure OpenAI
    # using the settings found in a local `.env` file, then calls it once.
    import dotenv

    from kotaemon.llms.chats.openai import AzureChatOpenAI
    from kotaemon.post_processing.extractor import RegexExtractor
    from kotaemon.prompt.base import BasePromptComponent

    def identity(x):
        return x

    env_values = dotenv.dotenv_values(".env")

    pipeline = GatedBranchingPipeline()
    llm = AzureChatOpenAI(
        openai_api_base=env_values.get("OPENAI_API_BASE", ""),
        openai_api_key=env_values.get("OPENAI_API_KEY", ""),
        openai_api_version=env_values.get("OPENAI_API_VERSION", ""),
        deployment_name="dummy-q2-gpt35",
        temperature=0,
        request_timeout=600,
    )

    # One gated branch per digit; a branch fires when its digit occurs in the
    # condition text.
    for digit in range(3):
        gated_branch = GatedLinearPipeline(
            prompt=BasePromptComponent(template=f"what is {digit} in Japanese ?"),
            condition=RegexExtractor(pattern=f"{digit}"),
            llm=llm,
            post_processor=identity,
        )
        pipeline.add_branch(gated_branch)

    pipeline(condition_text="1")

View File

@ -0,0 +1,153 @@
from typing import Any, Callable, Optional, Union
from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document, IO_Type
from kotaemon.llms.chats.base import ChatLLM
from kotaemon.llms.completions.base import LLM
from kotaemon.prompt.base import BasePromptComponent
class SimpleLinearPipeline(BaseComponent):
    """
    A simple pipeline for running a function with a prompt, a language model, and an
    optional post-processor.

    Attributes:
        prompt (BasePromptComponent): The prompt component used to generate the
            initial input.
        llm (Union[ChatLLM, LLM]): The language model component used to generate
            the output.
        post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An
            optional post-processor component or function.

    Example Usage:
        from kotaemon.llms.chats.openai import AzureChatOpenAI
        from kotaemon.prompt.base import BasePromptComponent

        def identity(x):
            return x

        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        pipeline = SimpleLinearPipeline(
            prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
            llm=llm,
            post_processor=identity,
        )
        print(pipeline(word="lone"))
    """

    prompt: BasePromptComponent
    llm: Union[ChatLLM, LLM]
    post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]

    def run(
        self,
        *,
        llm_kwargs: Optional[dict] = None,
        post_processor_kwargs: Optional[dict] = None,
        **prompt_kwargs,
    ):
        """
        Run prompt -> llm -> post_processor and wrap the result in a Document.

        Args:
            llm_kwargs (dict): Keyword arguments for the llm call. None (the
                default) means no extra arguments.
            post_processor_kwargs (dict): Keyword arguments for the
                post_processor. None (the default) means no extra arguments.
            **prompt_kwargs: Keyword arguments for populating the prompt.

        Returns:
            Document: The final output of the pipeline as a Document object.
        """
        # None is an allowed value per the Optional annotation; normalise it here
        # so `**` unpacking below never receives None, and so no mutable dict is
        # shared between calls as a default argument.
        llm_kwargs = {} if llm_kwargs is None else llm_kwargs
        post_processor_kwargs = (
            {} if post_processor_kwargs is None else post_processor_kwargs
        )

        prompt = self.prompt(**prompt_kwargs)
        llm_output = self.llm(prompt.text, **llm_kwargs)

        if self.post_processor is not None:
            final_output = self.post_processor(llm_output, **post_processor_kwargs)
        else:
            # No post-processor configured: pass the raw LLM output through.
            final_output = llm_output

        return Document(final_output)
class GatedLinearPipeline(SimpleLinearPipeline):
    """
    A pipeline that extends the SimpleLinearPipeline class and adds a condition
    attribute. The linear pipeline only runs when the condition returns a truthy
    value for the supplied condition text.

    Attributes:
        condition (Callable[[IO_Type], Any]): A callable that is given the
            condition text; a truthy return value lets the pipeline run.

    Example Usage:
        from kotaemon.llms.chats.openai import AzureChatOpenAI
        from kotaemon.post_processing.extractor import RegexExtractor
        from kotaemon.prompt.base import BasePromptComponent

        def identity(x):
            return x

        llm = AzureChatOpenAI(
            openai_api_base="your openai api base",
            openai_api_key="your openai api key",
            openai_api_version="your openai api version",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=600,
        )

        pipeline = GatedLinearPipeline(
            prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
            condition=RegexExtractor(pattern="some pattern"),
            llm=llm,
            post_processor=identity,
        )
        print(pipeline(condition_text="some pattern", word="lone"))
        print(pipeline(condition_text="other pattern", word="lone"))
    """

    condition: Callable[[IO_Type], Any]

    def run(
        self,
        *,
        condition_text: Optional[str] = None,
        llm_kwargs: Optional[dict] = None,
        post_processor_kwargs: Optional[dict] = None,
        **prompt_kwargs,
    ) -> Document:
        """
        Run the pipeline if the condition accepts `condition_text`.

        Args:
            condition_text (str): The condition text to evaluate. Must not be None.
            llm_kwargs (dict): Additional keyword arguments for the language model
                call. None (the default) means no extra arguments.
            post_processor_kwargs (dict): Additional keyword arguments for the
                post-processor. None (the default) means no extra arguments.
            **prompt_kwargs: Keyword arguments for populating the prompt.

        Returns:
            Document: The output of the linear pipeline when the condition is
            truthy, otherwise `Document(None)`.

        Raises:
            ValueError: If condition_text is None
        """
        if condition_text is None:
            raise ValueError("`condition_text` must be provided")

        if not self.condition(condition_text):
            # Gate closed: skip the prompt and LLM entirely.
            return Document(None)

        # Always hand real dicts to the parent so that `**` unpacking there is
        # safe even when the caller passed None explicitly.
        return super().run(
            llm_kwargs={} if llm_kwargs is None else llm_kwargs,
            post_processor_kwargs=(
                {} if post_processor_kwargs is None else post_processor_kwargs
            ),
            **prompt_kwargs,
        )

View File

@ -1,12 +1,43 @@
from typing import Any, Optional
from haystack.schema import Document as HaystackDocument from haystack.schema import Document as HaystackDocument
from llama_index.bridge.pydantic import Field from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument from llama_index.schema import Document as BaseDocument
from pyparsing import TypeVar
IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon" SAMPLE_TEXT = "A sample Document from kotaemon"
class Document(BaseDocument): class Document(BaseDocument):
"""Base document class, mostly inherited from Document class from llama-index""" """
Base document class, mostly inherited from Document class from llama-index.
This class accept one positional argument `content` of an arbitrary type, which will
store the raw content of the document. If specified, the class will use
`content` to initialize the base llama_index class.
"""
content: Any
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
    """
    Build a document from raw `content`, another Document, or keyword fields.

    Args:
        content: Raw content of any type. When it is a Document, the keyword
            arguments are replaced by that document's `dict()`. When it is
            None, `content` falls back to the `text` keyword, then to the
            `embedding` keyword, if either is given and not None.
        *args: Forwarded to the base class constructor.
        **kwargs: Field values forwarded to the base class constructor.
    """
    if isinstance(content, Document):
        # Copy constructor: the source document's fields replace any kwargs.
        kwargs = content.dict()
    elif content is not None:
        kwargs["content"] = content
        # Falsy content (0, empty containers, "") is stored with empty text.
        kwargs["text"] = str(content) if content else ""
    else:
        for fallback_key in ("text", "embedding"):
            if kwargs.get(fallback_key) is not None:
                kwargs["content"] = kwargs[fallback_key]
                break
    super().__init__(*args, **kwargs)
def __bool__(self):
    """Return the truthiness of the stored `content`."""
    return True if self.content else False
@classmethod @classmethod
def example(cls) -> "Document": def example(cls) -> "Document":
@ -23,7 +54,7 @@ class Document(BaseDocument):
return HaystackDocument(content=text, meta=metadata) return HaystackDocument(content=text, meta=metadata)
def __str__(self): def __str__(self):
return self.text return str(self.content)
class RetrievedDocument(Document): class RetrievedDocument(Document):

View File

@ -1,22 +1,42 @@
import re import re
from typing import Dict, List from typing import Callable, Dict, List, Union
from theflow import Param
from kotaemon.base import BaseComponent from kotaemon.base import BaseComponent
from kotaemon.documents.base import Document from kotaemon.documents.base import Document
class ExtractorOutput(Document):
    """
    Represents the output of an extractor.

    Extends `Document` with the list of matches the extractor produced.
    """

    # Extracted strings; the extractors in this module pass their mapped
    # matches here.
    matches: List[str]
class RegexExtractor(BaseComponent): class RegexExtractor(BaseComponent):
""" """
Simple class for extracting text from a document using a regex pattern. Simple class for extracting text from a document using a regex pattern.
Args: Args:
pattern (str): The regex pattern to use. pattern (List[str]): The regex pattern(s) to use.
output_map (dict, optional): A mapping from extracted text to the output_map (dict, optional): A mapping from extracted text to the
desired output. Defaults to None. desired output. Defaults to None.
""" """
pattern: str class Config:
output_map: Dict[str, str] = {} middleware_switches = {"theflow.middleware.CachingMiddleware": False}
pattern: List[str]
output_map: Union[Dict[str, str], Callable[[str], str]] = Param(
default_callback=lambda *_: {}
)
def __init__(self, pattern: Union[str, List[str]], **kwargs):
    """
    Create the extractor from one pattern or a list of patterns.

    Args:
        pattern: A single regex string or a list of regex strings. A single
            string is stored as a one-element list.
        **kwargs: Forwarded to the base class constructor.
    """
    patterns = [pattern] if isinstance(pattern, str) else pattern
    super().__init__(pattern=patterns, **kwargs)
@staticmethod @staticmethod
def run_raw_static(pattern: str, text: str) -> List[str]: def run_raw_static(pattern: str, text: str) -> List[str]:
@ -50,28 +70,34 @@ class RegexExtractor(BaseComponent):
if not output_map: if not output_map:
return text return text
return str(output_map.get(text, text)) if isinstance(output_map, dict):
return output_map.get(text, text)
def run_raw(self, text: str) -> List[Document]: return output_map(text)
def run_raw(self, text: str) -> ExtractorOutput:
""" """
Runs the raw text through the static pattern and output mapping, returning a Matches the raw text against the pattern and rans the output mapping, returning
list of strings. an instance of ExtractorOutput.
Args: Args:
text (str): The raw text to be processed. text (str): The raw text to be processed.
Returns: Returns:
List[str]: The processed output as a list of strings. ExtractorOutput: The processed output as a list of ExtractorOutput.
""" """
output = self.run_raw_static(self.pattern, text) output = sum(
[self.run_raw_static(p, text) for p in self.pattern], []
) # type: List[str]
output = [self.map_output(text, self.output_map) for text in output] output = [self.map_output(text, self.output_map) for text in output]
return [ return ExtractorOutput(
Document(text=text, metadata={"origin": "RegexExtractor"}) text=output[0] if output else "",
for text in output matches=output,
] metadata={"origin": "RegexExtractor"},
)
def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]: def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]:
""" """
Runs a batch of raw text inputs through the `run_raw()` method and returns the Runs a batch of raw text inputs through the `run_raw()` method and returns the
output for each input. output for each input.
@ -80,29 +106,28 @@ class RegexExtractor(BaseComponent):
text_batch (List[str]): A list of raw text inputs to process. text_batch (List[str]): A list of raw text inputs to process.
Returns: Returns:
List[List[str]]: A list of lists containing the output for each input in the List[ExtractorOutput]: A list containing the output for each input in the
batch. batch.
""" """
batch_output = [self.run_raw(each_text) for each_text in text_batch] batch_output = [self.run_raw(each_text) for each_text in text_batch]
return batch_output return batch_output
def run_document(self, document: Document) -> List[Document]: def run_document(self, document: Document) -> ExtractorOutput:
""" """
Run the document through the regex extractor and return a list of extracted Run the document through the regex extractor and return an extracted document.
documents.
Args: Args:
document (Document): The input document. document (Document): The input document.
Returns: Returns:
List[Document]: A list of extracted documents. ExtractorOutput: The extracted content.
""" """
return self.run_raw(document.text) return self.run_raw(document.text)
def run_batch_document( def run_batch_document(
self, document_batch: List[Document] self, document_batch: List[Document]
) -> List[List[Document]]: ) -> List[ExtractorOutput]:
""" """
Runs a batch of documents through the `run_document` function and returns the Runs a batch of documents through the `run_document` function and returns the
output for each document. output for each document.
@ -113,15 +138,15 @@ class RegexExtractor(BaseComponent):
batch of documents to process. batch of documents to process.
Returns: Returns:
List[List[Document]]: A list of lists where each inner list contains the List[ExtractorOutput]: A list contains the output ExtractorOutput for each
output Document for each input Document in the batch. input Document in the batch.
Example: Example:
document1 = Document(...) document1 = Document(...)
document2 = Document(...) document2 = Document(...)
document_batch = [document1, document2] document_batch = [document1, document2]
batch_output = self.run_batch_document(document_batch) batch_output = self.run_batch_document(document_batch)
# batch_output will be [[output1_document1, ...], [output1_document2, ...]] # batch_output will be [output1_document1, output1_document2]
""" """
batch_output = [ batch_output = [
@ -162,3 +187,22 @@ class RegexExtractor(BaseComponent):
return True return True
return False return False
class FirstMatchRegexExtractor(RegexExtractor):
    """
    Regex extractor that stops at the first pattern producing any match.

    Patterns are tried in list order. Only the matches of the first pattern that
    matches are returned; later patterns are not evaluated.
    """

    pattern: List[str]

    def run_raw(self, text: str) -> ExtractorOutput:
        """
        Match `text` against each pattern in order and return the first hit.

        Args:
            text (str): The raw text to be processed.

        Returns:
            ExtractorOutput: `matches` holds the mapped matches of the first
            pattern that matched and `text` holds the first of them. When no
            pattern matches, `matches` is empty and `text` is the empty string.
        """
        for p in self.pattern:
            found = self.run_raw_static(p, text)
            if found:
                mapped = [self.map_output(match, self.output_map) for match in found]
                return ExtractorOutput(
                    text=mapped[0],
                    matches=mapped,
                    metadata={"origin": "FirstMatchRegexExtractor"},
                )

        # Use "" rather than None for the no-match case: it matches what
        # RegexExtractor.run_raw returns and is a valid value for a string field.
        return ExtractorOutput(
            text="", matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
        )

View File

@ -15,6 +15,9 @@ class BasePromptComponent(BaseComponent):
given template. given template.
""" """
class Config:
middleware_switches = {"theflow.middleware.CachingMiddleware": False}
def __init__(self, template: Union[str, PromptTemplate], **kwargs): def __init__(self, template: Union[str, PromptTemplate], **kwargs):
super().__init__() super().__init__()
self.template = ( self.template = (

141
tests/test_composite.py Normal file
View File

@ -0,0 +1,141 @@
import pytest
from kotaemon.composite import (
GatedBranchingPipeline,
GatedLinearPipeline,
SimpleBranchingPipeline,
SimpleLinearPipeline,
)
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
@pytest.fixture
def mock_llm():
    """Azure chat model built with placeholder credentials."""
    settings = dict(
        openai_api_base="OPENAI_API_BASE",
        openai_api_key="OPENAI_API_KEY",
        openai_api_version="OPENAI_API_VERSION",
        deployment_name="dummy-q2-gpt35",
        temperature=0,
        request_timeout=600,
    )
    return AzureChatOpenAI(**settings)
@pytest.fixture
def mock_post_processor():
    """Regex extractor that picks out runs of digits."""
    digits_extractor = RegexExtractor(pattern=r"\d+")
    return digits_extractor
@pytest.fixture
def mock_prompt():
    """Prompt component with a single `value` placeholder."""
    prompt_component = BasePromptComponent(template="Test prompt {value}")
    return prompt_component
@pytest.fixture
def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
    """Linear pipeline wired from the prompt, llm and post-processor fixtures."""
    components = dict(
        prompt=mock_prompt, llm=mock_llm, post_processor=mock_post_processor
    )
    return SimpleLinearPipeline(**components)
@pytest.fixture
def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
    """Gated pipeline whose condition is the regex pattern "positive"."""
    gate = RegexExtractor(pattern="positive")
    return GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=gate,
    )
@pytest.fixture
def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
    """Gated pipeline whose condition is the regex pattern "negative"."""
    gate = RegexExtractor(pattern="negative")
    return GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=gate,
    )
def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
    """The linear pipeline calls the LLM once and post-processes its output."""
    patched_run = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="This is a test 123"
    )

    document = mock_simple_linear_pipeline.run(value="abc")

    assert document.text == "123"
    assert patched_run.call_count == 1
def test_gated_linear_pipeline_run_positive(
    mocker, mock_gated_linear_pipeline_positive
):
    """A matching condition text lets the gated pipeline call the LLM."""
    patched_run = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="This is a test 123."
    )

    call_arguments = {"value": "abc", "condition_text": "positive condition"}
    document = mock_gated_linear_pipeline_positive.run(**call_arguments)

    assert document.text == "123"
    assert patched_run.call_count == 1
def test_gated_linear_pipeline_run_negative(
    mocker, mock_gated_linear_pipeline_positive
):
    """A non-matching condition text short-circuits before the LLM is called."""
    patched_run = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="This is a test 123."
    )

    # The "positive" gate is deliberately fed text that it does not match.
    call_arguments = {"value": "abc", "condition_text": "negative condition"}
    document = mock_gated_linear_pipeline_positive.run(**call_arguments)

    assert document.content is None
    assert patched_run.call_count == 0
def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
    """Every branch runs once and the outputs come back in branch order."""
    llm_replies = [
        "This is a test 123.",
        "a quick brown fox",
        "jumps over the lazy dog 456",
    ]
    patched_run = mocker.patch.object(
        AzureChatOpenAI,
        "run",
        side_effect=llm_replies,
    )

    pipeline = SimpleBranchingPipeline()
    # The same linear pipeline is registered three times.
    for _ in range(3):
        pipeline.add_branch(mock_simple_linear_pipeline)

    outputs = pipeline.run(value="abc")
    extracted = [document.text for document in outputs]

    assert len(outputs) == 3
    assert extracted == ["123", "", "456"]
    assert patched_run.call_count == 3
def test_simple_gated_branching_pipeline_run(
    mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
):
    """Only branches whose gate matches call the LLM; empty outputs fall through."""
    patched_run = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="a quick brown fox"
    )

    pipeline = GatedBranchingPipeline()
    branch_order = (
        mock_gated_linear_pipeline_negative,
        mock_gated_linear_pipeline_positive,
        mock_gated_linear_pipeline_positive,
    )
    for branch in branch_order:
        pipeline.add_branch(branch)

    document = pipeline.run(value="abc", condition_text="positive condition")

    assert document.text == ""
    assert patched_run.call_count == 2

49
tests/test_documents.py Normal file
View File

@ -0,0 +1,49 @@
from haystack.schema import Document as HaystackDocument
from kotaemon.documents.base import Document, RetrievedDocument
def test_document_constructor_with_builtin_types():
    """Document stores any builtin value as content and mirrors its truthiness."""
    # `tuple()` (an empty tuple instance) replaces the bare `tuple` class so the
    # list covers instances of each builtin container, like its neighbours.
    for value in ["str", 1, {}, set(), [], tuple(), None]:
        doc = Document(value)
        # Falsy content is stored with empty text, anything else as str(value).
        assert doc.text == (str(value) if value else "")
        assert doc.content == value
        assert bool(doc) == bool(value)
def test_document_constructor_with_document():
    """Constructing a Document from a Document copies text and content."""
    original = Document("Sample text")
    clone = Document(original)

    assert clone.text == original.text
    assert clone.content == original.content
def test_document_to_haystack_format():
    """Conversion to the Haystack document type keeps text and metadata."""
    metadata = {"filename": "sample.txt"}
    doc = Document("Sample text", metadata=metadata)

    converted = doc.to_haystack_format()

    assert isinstance(converted, HaystackDocument)
    assert converted.content == doc.text
    assert converted.meta == metadata
def test_retrieved_document_default_values():
    """Score and retrieval metadata default to 0.0 and an empty dict."""
    expected_text = "text"
    retrieved = RetrievedDocument(text=expected_text)

    assert retrieved.text == expected_text
    assert retrieved.score == 0.0
    assert retrieved.retrieval_metadata == {}
def test_retrieved_document_attributes():
    """Explicitly supplied score and retrieval metadata are stored unchanged."""
    expected = {
        "text": "text",
        "score": 0.8,
        "retrieval_metadata": {"source": "retrieval_system"},
    }

    retrieved = RetrievedDocument(**expected)

    assert retrieved.text == expected["text"]
    assert retrieved.score == expected["score"]
    assert retrieved.retrieval_metadata == expected["retrieval_metadata"]

View File

@ -14,8 +14,8 @@ def regex_extractor():
def test_run_document(regex_extractor): def test_run_document(regex_extractor):
document = Document(text="This is a test. 1 2 3") document = Document(text="This is a test. 1 2 3")
extracted_document = regex_extractor(document) extracted_document = regex_extractor(document)
extracted_texts = [each.text for each in extracted_document] assert extracted_document.text == "One"
assert extracted_texts == ["One", "Two", "Three"] assert extracted_document.matches == ["One", "Two", "Three"]
def test_is_document(regex_extractor): def test_is_document(regex_extractor):
@ -30,11 +30,13 @@ def test_is_batch(regex_extractor):
def test_run_raw(regex_extractor): def test_run_raw(regex_extractor):
output = regex_extractor("This is a test. 123") output = regex_extractor("This is a test. 123")
output = [each.text for each in output] assert output.text == "123"
assert output == ["123"] assert output.matches == ["123"]
def test_run_batch_raw(regex_extractor): def test_run_batch_raw(regex_extractor):
output = regex_extractor(["This is a test. 123", "456"]) output = regex_extractor(["This is a test. 123", "456"])
output = [[each.text for each in batch] for batch in output] extracted_text = [each.text for each in output]
assert output == [["123"], ["456"]] extracted_matches = [each.matches for each in output]
assert extracted_text == ["123", "456"]
assert extracted_matches == [["123"], ["456"]]

View File

@ -54,10 +54,7 @@ def test_run():
result = prompt() result = prompt()
assert ( assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One"
result.text
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
)
def test_set_method(): def test_set_method():