diff --git a/.gitignore b/.gitignore
index a8af57d..2122691 100644
--- a/.gitignore
+++ b/.gitignore
@@ -458,3 +458,4 @@
 logs/
 S.gpg-agent*
 .vscode/settings.json
+examples/example1/assets
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index bb3084c..69878d8 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -48,5 +48,10 @@ repos:
     hooks:
       - id: mypy
        additional_dependencies: [types-PyYAML==6.0.12.11, "types-requests"]
-        args: ["--check-untyped-defs", "--ignore-missing-imports"]
+        args:
+          [
+            "--check-untyped-defs",
+            "--ignore-missing-imports",
+            "--new-type-inference",
+          ]
        exclude: "^templates/"
diff --git a/knowledgehub/composite/__init__.py b/knowledgehub/composite/__init__.py
new file mode 100644
index 0000000..fac7d70
--- /dev/null
+++ b/knowledgehub/composite/__init__.py
@@ -0,0 +1,9 @@
+from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
+from .linear import GatedLinearPipeline, SimpleLinearPipeline
+
+__all__ = [
+    "SimpleLinearPipeline",
+    "GatedLinearPipeline",
+    "SimpleBranchingPipeline",
+    "GatedBranchingPipeline",
+]
diff --git a/knowledgehub/composite/branching.py b/knowledgehub/composite/branching.py
new file mode 100644
index 0000000..fb0b192
--- /dev/null
+++ b/knowledgehub/composite/branching.py
@@ -0,0 +1,182 @@
+from typing import List, Optional
+
+from theflow import Param
+
+from kotaemon.base import BaseComponent
+from kotaemon.composite.linear import GatedLinearPipeline
+from kotaemon.documents.base import Document
+
+
+class SimpleBranchingPipeline(BaseComponent):
+    """
+    A simple branching pipeline for executing multiple branches.
+
+    Attributes:
+        branches (List[BaseComponent]): The list of branches to be executed.
+
+    Example Usage:
+        from kotaemon.composite import GatedLinearPipeline, SimpleBranchingPipeline
+        from kotaemon.llms.chats.openai import AzureChatOpenAI
+        from kotaemon.post_processing.extractor import RegexExtractor
+        from kotaemon.prompt.base import BasePromptComponent
+
+        def identity(x):
+            return x
+
+        pipeline = SimpleBranchingPipeline()
+        llm = AzureChatOpenAI(
+            openai_api_base="your openai api base",
+            openai_api_key="your openai api key",
+            openai_api_version="your openai api version",
+            deployment_name="dummy-q2-gpt35",
+            temperature=0,
+            request_timeout=600,
+        )
+
+        for i in range(3):
+            pipeline.add_branch(
+                GatedLinearPipeline(
+                    prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
+                    condition=RegexExtractor(pattern=f"{i}"),
+                    llm=llm,
+                    post_processor=identity,
+                )
+            )
+        print(pipeline(condition_text="1"))
+        print(pipeline(condition_text="2"))
+        print(pipeline(condition_text="12"))
+    """
+
+    branches: List[BaseComponent] = Param(default_callback=lambda *_: [])
+
+    def add_branch(self, component: BaseComponent):
+        """
+        Add a new branch to the pipeline.
+
+        Args:
+            component (BaseComponent): The branch component to be added.
+        """
+        self.branches.append(component)
+
+    def run(self, **prompt_kwargs):
+        """
+        Execute the pipeline by running each branch and return the outputs as a list.
+
+        Args:
+            **prompt_kwargs: Keyword arguments for the branches.
+
+        Returns:
+            List: The outputs of each branch as a list.
+        """
+        output = []
+        for i, branch in enumerate(self.branches):
+            self._prepare_child(branch, name=f"branch-{i}")
+            output.append(branch(**prompt_kwargs))
+
+        return output
+
+
+class GatedBranchingPipeline(SimpleBranchingPipeline):
+    """
+    A simple gated branching pipeline for executing multiple branches based on a
+    condition.
+
+    This class extends the SimpleBranchingPipeline class and adds the ability to
+    execute the branches sequentially until one returns a non-empty output that
+    satisfies the condition.
+
+    Attributes:
+        branches (List[BaseComponent]): The list of branches to be executed.
+
+    Example Usage:
+        from kotaemon.composite import GatedBranchingPipeline, GatedLinearPipeline
+        from kotaemon.llms.chats.openai import AzureChatOpenAI
+        from kotaemon.post_processing.extractor import RegexExtractor
+        from kotaemon.prompt.base import BasePromptComponent
+
+        def identity(x):
+            return x
+
+        pipeline = GatedBranchingPipeline()
+        llm = AzureChatOpenAI(
+            openai_api_base="your openai api base",
+            openai_api_key="your openai api key",
+            openai_api_version="your openai api version",
+            deployment_name="dummy-q2-gpt35",
+            temperature=0,
+            request_timeout=600,
+        )
+
+        for i in range(3):
+            pipeline.add_branch(
+                GatedLinearPipeline(
+                    prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
+                    condition=RegexExtractor(pattern=f"{i}"),
+                    llm=llm,
+                    post_processor=identity,
+                )
+            )
+        print(pipeline(condition_text="1"))
+        print(pipeline(condition_text="2"))
+    """
+
+    def run(self, *, condition_text: Optional[str] = None, **prompt_kwargs):
+        """
+        Execute the pipeline by running each branch and return the output of the first
+        branch that produces a non-empty output for the provided condition.
+
+        Args:
+            condition_text (str): The condition text to evaluate for each branch.
+                Defaults to None.
+            **prompt_kwargs: Keyword arguments for the branches.
+
+        Returns:
+            Union[OutputType, None]: The output of the first branch that satisfies the
+                condition, or a Document wrapping None if no branch satisfies it.
+
+        Raises:
+            ValueError: If condition_text is None.
+        """
+        if condition_text is None:
+            raise ValueError("`condition_text` must be provided.")
+
+        for i, branch in enumerate(self.branches):
+            self._prepare_child(branch, name=f"branch-{i}")
+            output = branch(condition_text=condition_text, **prompt_kwargs)
+            if output:
+                return output
+
+        return Document(None)
+
+
+if __name__ == "__main__":
+    import dotenv
+
+    from kotaemon.llms.chats.openai import AzureChatOpenAI
+    from kotaemon.post_processing.extractor import RegexExtractor
+    from kotaemon.prompt.base import BasePromptComponent
+
+    def identity(x):
+        return x
+
+    secrets = dotenv.dotenv_values(".env")
+
+    pipeline = GatedBranchingPipeline()
+    llm = AzureChatOpenAI(
+        openai_api_base=secrets.get("OPENAI_API_BASE", ""),
+        openai_api_key=secrets.get("OPENAI_API_KEY", ""),
+        openai_api_version=secrets.get("OPENAI_API_VERSION", ""),
+        deployment_name="dummy-q2-gpt35",
+        temperature=0,
+        request_timeout=600,
+    )
+
+    for i in range(3):
+        pipeline.add_branch(
+            GatedLinearPipeline(
+                prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
+                condition=RegexExtractor(pattern=f"{i}"),
+                llm=llm,
+                post_processor=identity,
+            )
+        )
+    pipeline(condition_text="1")
diff --git a/knowledgehub/composite/linear.py b/knowledgehub/composite/linear.py
new file mode 100644
index 0000000..288ba62
--- /dev/null
+++ b/knowledgehub/composite/linear.py
@@ -0,0 +1,153 @@
+from typing import Any, Callable, Optional, Union
+
+from kotaemon.base import BaseComponent
+from kotaemon.documents.base import Document, IO_Type
+from kotaemon.llms.chats.base import ChatLLM
+from kotaemon.llms.completions.base import LLM
+from kotaemon.prompt.base import BasePromptComponent
+
+
+class SimpleLinearPipeline(BaseComponent):
+    """
+    A simple pipeline that chains a prompt, a language model, and an
+    optional post-processor.
+
+    Attributes:
+        prompt (BasePromptComponent): The prompt component used to generate the initial
+            input.
+        llm (Union[ChatLLM, LLM]): The language model component used to generate the
+            output.
+        post_processor (Union[BaseComponent, Callable[[IO_Type], IO_Type]]): An
+            optional post-processor component or function.
+
+    Example Usage:
+        from kotaemon.llms.chats.openai import AzureChatOpenAI
+        from kotaemon.prompt.base import BasePromptComponent
+
+        def identity(x):
+            return x
+
+        llm = AzureChatOpenAI(
+            openai_api_base="your openai api base",
+            openai_api_key="your openai api key",
+            openai_api_version="your openai api version",
+            deployment_name="dummy-q2-gpt35",
+            temperature=0,
+            request_timeout=600,
+        )
+
+        pipeline = SimpleLinearPipeline(
+            prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
+            llm=llm,
+            post_processor=identity,
+        )
+        print(pipeline(word="lone"))
+    """
+
+    prompt: BasePromptComponent
+    llm: Union[ChatLLM, LLM]
+    post_processor: Union[BaseComponent, Callable[[IO_Type], IO_Type]]
+
+    def run(
+        self,
+        *,
+        llm_kwargs: Optional[dict] = {},
+        post_processor_kwargs: Optional[dict] = {},
+        **prompt_kwargs,
+    ):
+        """
+        Run the pipeline with the given arguments and return the final output as a
+        Document object.
+
+        Args:
+            llm_kwargs (dict): Keyword arguments for the llm call.
+            post_processor_kwargs (dict): Keyword arguments for the post_processor.
+            **prompt_kwargs: Keyword arguments for populating the prompt.
+
+        Returns:
+            Document: The final output of the pipeline as a Document object.
+        """
+        prompt = self.prompt(**prompt_kwargs)
+        llm_output = self.llm(prompt.text, **llm_kwargs)
+        if self.post_processor is not None:
+            final_output = self.post_processor(llm_output, **post_processor_kwargs)
+        else:
+            final_output = llm_output
+
+        return Document(final_output)
+
+
+class GatedLinearPipeline(SimpleLinearPipeline):
+    """
+    A pipeline that extends the SimpleLinearPipeline class and adds a condition
+    attribute.
+
+    Attributes:
+        condition (Callable[[IO_Type], Any]): A callable function that represents the
+            condition.
+
+    Example Usage:
+        from kotaemon.llms.chats.openai import AzureChatOpenAI
+        from kotaemon.post_processing.extractor import RegexExtractor
+        from kotaemon.prompt.base import BasePromptComponent
+
+        def identity(x):
+            return x
+
+        llm = AzureChatOpenAI(
+            openai_api_base="your openai api base",
+            openai_api_key="your openai api key",
+            openai_api_version="your openai api version",
+            deployment_name="dummy-q2-gpt35",
+            temperature=0,
+            request_timeout=600,
+        )
+
+        pipeline = GatedLinearPipeline(
+            prompt=BasePromptComponent(template="what is {word} in Japanese ?"),
+            condition=RegexExtractor(pattern="some pattern"),
+            llm=llm,
+            post_processor=identity,
+        )
+        print(pipeline(condition_text="some pattern", word="lone"))
+        print(pipeline(condition_text="other pattern", word="lone"))
+    """
+
+    condition: Callable[[IO_Type], Any]
+
+    def run(
+        self,
+        *,
+        condition_text: Optional[str] = None,
+        llm_kwargs: Optional[dict] = {},
+        post_processor_kwargs: Optional[dict] = {},
+        **prompt_kwargs,
+    ) -> Document:
+        """
+        Run the pipeline with the given arguments and return the final output as a
+        Document object.
+
+        Args:
+            condition_text (str): The condition text to evaluate. Defaults to None.
+            llm_kwargs (dict): Additional keyword arguments for the language model call.
+            post_processor_kwargs (dict): Additional keyword arguments for the
+                post-processor.
+            **prompt_kwargs: Keyword arguments for populating the prompt.
+
+        Returns:
+            Document: The final output of the pipeline as a Document object.
+
+        Raises:
+            ValueError: If condition_text is None.
+        """
+        if condition_text is None:
+            raise ValueError("`condition_text` must be provided")
+
+        if self.condition(condition_text):
+            return super().run(
+                llm_kwargs=llm_kwargs,
+                post_processor_kwargs=post_processor_kwargs,
+                **prompt_kwargs,
+            )
+
+        return Document(None)
diff --git a/knowledgehub/documents/base.py b/knowledgehub/documents/base.py
index 579031b..7044a12 100644
--- a/knowledgehub/documents/base.py
+++ b/knowledgehub/documents/base.py
@@ -1,12 +1,42 @@
+from typing import Any, Optional, TypeVar
+
 from haystack.schema import Document as HaystackDocument
 from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
 
+IO_Type = TypeVar("IO_Type", "Document", str)
 SAMPLE_TEXT = "A sample Document from kotaemon"
 
 
 class Document(BaseDocument):
-    """Base document class, mostly inherited from Document class from llama-index"""
+    """
+    Base document class, mostly inherited from the Document class of llama-index.
+
+    This class accepts one positional argument, `content`, of an arbitrary type, which
+    stores the raw content of the document. If specified, the class uses `content` to
+    initialize the base llama-index class.
+    """
+
+    content: Any
+
+    def __init__(self, content: Optional[Any] = None, *args, **kwargs):
+        if content is None:
+            if kwargs.get("text", None) is not None:
+                kwargs["content"] = kwargs["text"]
+            elif kwargs.get("embedding", None) is not None:
+                kwargs["content"] = kwargs["embedding"]
+        elif isinstance(content, Document):
+            kwargs = content.dict()
+        else:
+            kwargs["content"] = content
+            if content:
+                kwargs["text"] = str(content)
+            else:
+                kwargs["text"] = ""
+        super().__init__(*args, **kwargs)
+
+    def __bool__(self):
+        return bool(self.content)
 
     @classmethod
     def example(cls) -> "Document":
@@ -23,7 +53,7 @@ class Document(BaseDocument):
         return HaystackDocument(content=text, meta=metadata)
 
     def __str__(self):
-        return self.text
+        return str(self.content)
 
 
 class RetrievedDocument(Document):
diff --git a/knowledgehub/post_processing/extractor.py b/knowledgehub/post_processing/extractor.py
index c60077a..ee5633b 100644
--- a/knowledgehub/post_processing/extractor.py
+++ b/knowledgehub/post_processing/extractor.py
@@ -1,22 +1,42 @@
 import re
-from typing import Dict, List
+from typing import Callable, Dict, List, Union
+
+from theflow import Param
 
 from kotaemon.base import BaseComponent
 from kotaemon.documents.base import Document
 
 
+class ExtractorOutput(Document):
+    """
+    Represents the output of an extractor.
+    """
+
+    matches: List[str]
+
+
 class RegexExtractor(BaseComponent):
     """
     Simple class for extracting text from a document using a regex pattern.
 
     Args:
-        pattern (str): The regex pattern to use.
+        pattern (List[str]): The regex pattern(s) to use.
         output_map (dict, optional): A mapping from extracted text to the desired
             output. Defaults to None.
""" - pattern: str - output_map: Dict[str, str] = {} + class Config: + middleware_switches = {"theflow.middleware.CachingMiddleware": False} + + pattern: List[str] + output_map: Union[Dict[str, str], Callable[[str], str]] = Param( + default_callback=lambda *_: {} + ) + + def __init__(self, pattern: Union[str, List[str]], **kwargs): + if isinstance(pattern, str): + pattern = [pattern] + super().__init__(pattern=pattern, **kwargs) @staticmethod def run_raw_static(pattern: str, text: str) -> List[str]: @@ -50,28 +70,34 @@ class RegexExtractor(BaseComponent): if not output_map: return text - return str(output_map.get(text, text)) + if isinstance(output_map, dict): + return output_map.get(text, text) - def run_raw(self, text: str) -> List[Document]: + return output_map(text) + + def run_raw(self, text: str) -> ExtractorOutput: """ - Runs the raw text through the static pattern and output mapping, returning a - list of strings. + Matches the raw text against the pattern and rans the output mapping, returning + an instance of ExtractorOutput. Args: text (str): The raw text to be processed. Returns: - List[str]: The processed output as a list of strings. + ExtractorOutput: The processed output as a list of ExtractorOutput. """ - output = self.run_raw_static(self.pattern, text) + output = sum( + [self.run_raw_static(p, text) for p in self.pattern], [] + ) # type: List[str] output = [self.map_output(text, self.output_map) for text in output] - return [ - Document(text=text, metadata={"origin": "RegexExtractor"}) - for text in output - ] + return ExtractorOutput( + text=output[0] if output else "", + matches=output, + metadata={"origin": "RegexExtractor"}, + ) - def run_batch_raw(self, text_batch: List[str]) -> List[List[Document]]: + def run_batch_raw(self, text_batch: List[str]) -> List[ExtractorOutput]: """ Runs a batch of raw text inputs through the `run_raw()` method and returns the output for each input. @@ -80,29 +106,28 @@ class RegexExtractor(BaseComponent): text_batch (List[str]): A list of raw text inputs to process. Returns: - List[List[str]]: A list of lists containing the output for each input in the + List[ExtractorOutput]: A list containing the output for each input in the batch. """ batch_output = [self.run_raw(each_text) for each_text in text_batch] return batch_output - def run_document(self, document: Document) -> List[Document]: + def run_document(self, document: Document) -> ExtractorOutput: """ - Run the document through the regex extractor and return a list of extracted - documents. + Run the document through the regex extractor and return an extracted document. Args: document (Document): The input document. Returns: - List[Document]: A list of extracted documents. + ExtractorOutput: The extracted content. """ return self.run_raw(document.text) def run_batch_document( self, document_batch: List[Document] - ) -> List[List[Document]]: + ) -> List[ExtractorOutput]: """ Runs a batch of documents through the `run_document` function and returns the output for each document. @@ -113,15 +138,15 @@ class RegexExtractor(BaseComponent): batch of documents to process. Returns: - List[List[Document]]: A list of lists where each inner list contains the - output Document for each input Document in the batch. + List[ExtractorOutput]: A list contains the output ExtractorOutput for each + input Document in the batch. Example: document1 = Document(...) document2 = Document(...) 
             document_batch = [document1, document2]
             batch_output = self.run_batch_document(document_batch)
-            # batch_output will be [[output1_document1, ...], [output1_document2, ...]]
+            # batch_output will be [output1_document1, output1_document2]
         """
 
         batch_output = [
@@ -162,3 +187,24 @@ class RegexExtractor(BaseComponent):
             return True
 
         return False
+
+
+class FirstMatchRegexExtractor(RegexExtractor):
+    """Try each pattern in order and extract the matches of the first that hits."""
+
+    pattern: List[str]
+
+    def run_raw(self, text: str) -> ExtractorOutput:
+        for p in self.pattern:
+            output = self.run_raw_static(p, text)
+            if output:
+                output = [self.map_output(text, self.output_map) for text in output]
+                return ExtractorOutput(
+                    text=output[0],
+                    matches=output,
+                    metadata={"origin": "FirstMatchRegexExtractor"},
+                )
+
+        return ExtractorOutput(
+            text="", matches=[], metadata={"origin": "FirstMatchRegexExtractor"}
+        )
diff --git a/knowledgehub/prompt/base.py b/knowledgehub/prompt/base.py
index caee3b0..48b370c 100644
--- a/knowledgehub/prompt/base.py
+++ b/knowledgehub/prompt/base.py
@@ -15,6 +15,9 @@ class BasePromptComponent(BaseComponent):
     given template.
     """
 
+    class Config:
+        middleware_switches = {"theflow.middleware.CachingMiddleware": False}
+
     def __init__(self, template: Union[str, PromptTemplate], **kwargs):
         super().__init__()
         self.template = (
diff --git a/tests/test_composite.py b/tests/test_composite.py
new file mode 100644
index 0000000..66f7e94
--- /dev/null
+++ b/tests/test_composite.py
@@ -0,0 +1,141 @@
+import pytest
+
+from kotaemon.composite import (
+    GatedBranchingPipeline,
+    GatedLinearPipeline,
+    SimpleBranchingPipeline,
+    SimpleLinearPipeline,
+)
+from kotaemon.llms.chats.openai import AzureChatOpenAI
+from kotaemon.post_processing.extractor import RegexExtractor
+from kotaemon.prompt.base import BasePromptComponent
+
+
+@pytest.fixture
+def mock_llm():
+    return AzureChatOpenAI(
+        openai_api_base="OPENAI_API_BASE",
+        openai_api_key="OPENAI_API_KEY",
+        openai_api_version="OPENAI_API_VERSION",
+        deployment_name="dummy-q2-gpt35",
+        temperature=0,
+        request_timeout=600,
+    )
+
+
+@pytest.fixture
+def mock_post_processor():
+    return RegexExtractor(pattern=r"\d+")
+
+
+@pytest.fixture
+def mock_prompt():
+    return BasePromptComponent(template="Test prompt {value}")
+
+
+@pytest.fixture
+def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
+    return SimpleLinearPipeline(
+        prompt=mock_prompt, llm=mock_llm, post_processor=mock_post_processor
+    )
+
+
+@pytest.fixture
+def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
+    return GatedLinearPipeline(
+        prompt=mock_prompt,
+        llm=mock_llm,
+        post_processor=mock_post_processor,
+        condition=RegexExtractor(pattern="positive"),
+    )
+
+
+@pytest.fixture
+def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
+    return GatedLinearPipeline(
+        prompt=mock_prompt,
+        llm=mock_llm,
+        post_processor=mock_post_processor,
+        condition=RegexExtractor(pattern="negative"),
+    )
+
+
+def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
+    openai_mocker = mocker.patch.object(
+        AzureChatOpenAI, "run", return_value="This is a test 123"
+    )
+
+    result = mock_simple_linear_pipeline.run(value="abc")
+
+    assert result.text == "123"
+    assert openai_mocker.call_count == 1
+
+
+def test_gated_linear_pipeline_run_positive(
+    mocker, mock_gated_linear_pipeline_positive
+):
+    openai_mocker = mocker.patch.object(
+        AzureChatOpenAI, "run", return_value="This is a test 123."
+    )
+
+    result = mock_gated_linear_pipeline_positive.run(
+        value="abc", condition_text="positive condition"
+    )
+
+    assert result.text == "123"
+    assert openai_mocker.call_count == 1
+
+
+def test_gated_linear_pipeline_run_negative(
+    mocker, mock_gated_linear_pipeline_positive
+):
+    openai_mocker = mocker.patch.object(
+        AzureChatOpenAI, "run", return_value="This is a test 123."
+    )
+
+    result = mock_gated_linear_pipeline_positive.run(
+        value="abc", condition_text="negative condition"
+    )
+
+    assert result.content is None
+    assert openai_mocker.call_count == 0
+
+
+def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
+    openai_mocker = mocker.patch.object(
+        AzureChatOpenAI,
+        "run",
+        side_effect=[
+            "This is a test 123.",
+            "a quick brown fox",
+            "jumps over the lazy dog 456",
+        ],
+    )
+    pipeline = SimpleBranchingPipeline()
+    for _ in range(3):
+        pipeline.add_branch(mock_simple_linear_pipeline)
+
+    result = pipeline.run(value="abc")
+    texts = [each.text for each in result]
+
+    assert len(result) == 3
+    assert texts == ["123", "", "456"]
+    assert openai_mocker.call_count == 3
+
+
+def test_simple_gated_branching_pipeline_run(
+    mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
+):
+    openai_mocker = mocker.patch.object(
+        AzureChatOpenAI, "run", return_value="a quick brown fox"
+    )
+    pipeline = GatedBranchingPipeline()
+
+    pipeline.add_branch(mock_gated_linear_pipeline_negative)
+    pipeline.add_branch(mock_gated_linear_pipeline_positive)
+    pipeline.add_branch(mock_gated_linear_pipeline_positive)
+
+    result = pipeline.run(value="abc", condition_text="positive condition")
+
+    assert result.text == ""
+    assert openai_mocker.call_count == 2
diff --git a/tests/test_documents.py b/tests/test_documents.py
new file mode 100644
index 0000000..a76464d
--- /dev/null
+++ b/tests/test_documents.py
@@ -0,0 +1,49 @@
+from haystack.schema import Document as HaystackDocument
+
+from kotaemon.documents.base import Document, RetrievedDocument
+
+
+def test_document_constructor_with_builtin_types():
+    for value in ["str", 1, {}, set(), [], tuple, None]:
+        doc = Document(value)
+        assert doc.text == (str(value) if value else "")
+        assert doc.content == value
+        assert bool(doc) == bool(value)
+
+
+def test_document_constructor_with_document():
+    text = "Sample text"
+    doc1 = Document(text)
+    doc2 = Document(doc1)
+    assert doc2.text == doc1.text
+    assert doc2.content == doc1.content
+
+
+def test_document_to_haystack_format():
+    text = "Sample text"
+    metadata = {"filename": "sample.txt"}
+    doc = Document(text, metadata=metadata)
+    haystack_doc = doc.to_haystack_format()
+    assert isinstance(haystack_doc, HaystackDocument)
+    assert haystack_doc.content == doc.text
+    assert haystack_doc.meta == metadata
+
+
+def test_retrieved_document_default_values():
+    sample_text = "text"
+    retrieved_doc = RetrievedDocument(text=sample_text)
+    assert retrieved_doc.text == sample_text
+    assert retrieved_doc.score == 0.0
+    assert retrieved_doc.retrieval_metadata == {}
+
+
+def test_retrieved_document_attributes():
+    sample_text = "text"
+    score = 0.8
+    metadata = {"source": "retrieval_system"}
+    retrieved_doc = RetrievedDocument(
+        text=sample_text, score=score, retrieval_metadata=metadata
+    )
+    assert retrieved_doc.text == sample_text
+    assert retrieved_doc.score == score
+    assert retrieved_doc.retrieval_metadata == metadata
diff --git a/tests/test_post_processing.py b/tests/test_post_processing.py
index bd27ad7..aefa128 100644
--- a/tests/test_post_processing.py
+++ b/tests/test_post_processing.py
@@ -14,8 +14,8 @@
 def test_run_document(regex_extractor):
     document = Document(text="This is a test. 1 2 3")
     extracted_document = regex_extractor(document)
-    extracted_texts = [each.text for each in extracted_document]
-    assert extracted_texts == ["One", "Two", "Three"]
+    assert extracted_document.text == "One"
+    assert extracted_document.matches == ["One", "Two", "Three"]
 
 
 def test_is_document(regex_extractor):
@@ -30,11 +30,13 @@ def test_is_batch(regex_extractor):
 
 def test_run_raw(regex_extractor):
     output = regex_extractor("This is a test. 123")
-    output = [each.text for each in output]
-    assert output == ["123"]
+    assert output.text == "123"
+    assert output.matches == ["123"]
 
 
 def test_run_batch_raw(regex_extractor):
     output = regex_extractor(["This is a test. 123", "456"])
-    output = [[each.text for each in batch] for batch in output]
-    assert output == [["123"], ["456"]]
+    extracted_text = [each.text for each in output]
+    extracted_matches = [each.matches for each in output]
+    assert extracted_text == ["123", "456"]
+    assert extracted_matches == [["123"], ["456"]]
diff --git a/tests/test_prompt.py b/tests/test_prompt.py
index f71d5fe..51c5154 100644
--- a/tests/test_prompt.py
+++ b/tests/test_prompt.py
@@ -54,10 +54,7 @@ def test_run():
 
     result = prompt()
 
-    assert (
-        result.text
-        == "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
-    )
+    assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One"
 
 
 def test_set_method():
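
Usage sketch (not part of the diff): a minimal end-to-end example of the new composite API, assembled from the docstrings and tests above. The Azure credentials are placeholders, and `identity` stands in for any post-processor; this is an illustration, not an excerpt from the changeset.

    from kotaemon.composite import GatedBranchingPipeline, GatedLinearPipeline
    from kotaemon.llms.chats.openai import AzureChatOpenAI
    from kotaemon.post_processing.extractor import RegexExtractor
    from kotaemon.prompt.base import BasePromptComponent

    def identity(x):
        return x

    # Placeholder credentials; substitute a real Azure OpenAI deployment.
    llm = AzureChatOpenAI(
        openai_api_base="...",
        openai_api_key="...",
        openai_api_version="...",
        deployment_name="dummy-q2-gpt35",
        temperature=0,
        request_timeout=600,
    )

    # One gated branch per digit. Each branch wires a prompt, the LLM, and a
    # post-processor behind a RegexExtractor condition.
    pipeline = GatedBranchingPipeline()
    for i in range(3):
        pipeline.add_branch(
            GatedLinearPipeline(
                prompt=BasePromptComponent(template=f"what is {i} in Japanese ?"),
                condition=RegexExtractor(pattern=f"{i}"),
                llm=llm,
                post_processor=identity,
            )
        )

    # Only the branch whose condition matches `condition_text` calls the LLM;
    # the first non-empty output is returned, wrapped in a Document.
    print(pipeline(condition_text="2"))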