[AUR-395] Adopt Example1 disclaimer pipeline (#42)

* Adopt Example1 disclaimer pipeline
* Update Document class
* Add composite components
* Modify Extractor behaviours
This commit is contained in:
ian_Cin
2023-10-10 15:42:48 +07:00
committed by GitHub
parent 79cc60e6a2
commit 84f1fa8cbd
12 changed files with 654 additions and 37 deletions

141
tests/test_composite.py Normal file
View File

@@ -0,0 +1,141 @@
import pytest
from kotaemon.composite import (
GatedBranchingPipeline,
GatedLinearPipeline,
SimpleBranchingPipeline,
SimpleLinearPipeline,
)
from kotaemon.llms.chats.openai import AzureChatOpenAI
from kotaemon.post_processing.extractor import RegexExtractor
from kotaemon.prompt.base import BasePromptComponent
@pytest.fixture
def mock_llm():
return AzureChatOpenAI(
openai_api_base="OPENAI_API_BASE",
openai_api_key="OPENAI_API_KEY",
openai_api_version="OPENAI_API_VERSION",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=600,
)
@pytest.fixture
def mock_post_processor():
return RegexExtractor(pattern=r"\d+")
@pytest.fixture
def mock_prompt():
return BasePromptComponent(template="Test prompt {value}")
@pytest.fixture
def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
return SimpleLinearPipeline(
prompt=mock_prompt, llm=mock_llm, post_processor=mock_post_processor
)
@pytest.fixture
def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
return GatedLinearPipeline(
prompt=mock_prompt,
llm=mock_llm,
post_processor=mock_post_processor,
condition=RegexExtractor(pattern="positive"),
)
@pytest.fixture
def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
return GatedLinearPipeline(
prompt=mock_prompt,
llm=mock_llm,
post_processor=mock_post_processor,
condition=RegexExtractor(pattern="negative"),
)
def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
openai_mocker = mocker.patch.object(
AzureChatOpenAI, "run", return_value="This is a test 123"
)
result = mock_simple_linear_pipeline.run(value="abc")
assert result.text == "123"
assert openai_mocker.call_count == 1
def test_gated_linear_pipeline_run_positive(
mocker, mock_gated_linear_pipeline_positive
):
openai_mocker = mocker.patch.object(
AzureChatOpenAI, "run", return_value="This is a test 123."
)
result = mock_gated_linear_pipeline_positive.run(
value="abc", condition_text="positive condition"
)
assert result.text == "123"
assert openai_mocker.call_count == 1
def test_gated_linear_pipeline_run_negative(
mocker, mock_gated_linear_pipeline_positive
):
openai_mocker = mocker.patch.object(
AzureChatOpenAI, "run", return_value="This is a test 123."
)
result = mock_gated_linear_pipeline_positive.run(
value="abc", condition_text="negative condition"
)
assert result.content is None
assert openai_mocker.call_count == 0
def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
openai_mocker = mocker.patch.object(
AzureChatOpenAI,
"run",
side_effect=[
"This is a test 123.",
"a quick brown fox",
"jumps over the lazy dog 456",
],
)
pipeline = SimpleBranchingPipeline()
for _ in range(3):
pipeline.add_branch(mock_simple_linear_pipeline)
result = pipeline.run(value="abc")
texts = [each.text for each in result]
assert len(result) == 3
assert texts == ["123", "", "456"]
assert openai_mocker.call_count == 3
def test_simple_gated_branching_pipeline_run(
mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
):
openai_mocker = mocker.patch.object(
AzureChatOpenAI, "run", return_value="a quick brown fox"
)
pipeline = GatedBranchingPipeline()
pipeline.add_branch(mock_gated_linear_pipeline_negative)
pipeline.add_branch(mock_gated_linear_pipeline_positive)
pipeline.add_branch(mock_gated_linear_pipeline_positive)
result = pipeline.run(value="abc", condition_text="positive condition")
assert result.text == ""
assert openai_mocker.call_count == 2

49
tests/test_documents.py Normal file
View File

@@ -0,0 +1,49 @@
from haystack.schema import Document as HaystackDocument
from kotaemon.documents.base import Document, RetrievedDocument
def test_document_constructor_with_builtin_types():
for value in ["str", 1, {}, set(), [], tuple, None]:
doc = Document(value)
assert doc.text == (str(value) if value else "")
assert doc.content == value
assert bool(doc) == bool(value)
def test_document_constructor_with_document():
text = "Sample text"
doc1 = Document(text)
doc2 = Document(doc1)
assert doc2.text == doc1.text
assert doc2.content == doc1.content
def test_document_to_haystack_format():
text = "Sample text"
metadata = {"filename": "sample.txt"}
doc = Document(text, metadata=metadata)
haystack_doc = doc.to_haystack_format()
assert isinstance(haystack_doc, HaystackDocument)
assert haystack_doc.content == doc.text
assert haystack_doc.meta == metadata
def test_retrieved_document_default_values():
sample_text = "text"
retrieved_doc = RetrievedDocument(text=sample_text)
assert retrieved_doc.text == sample_text
assert retrieved_doc.score == 0.0
assert retrieved_doc.retrieval_metadata == {}
def test_retrieved_document_attributes():
sample_text = "text"
score = 0.8
metadata = {"source": "retrieval_system"}
retrieved_doc = RetrievedDocument(
text=sample_text, score=score, retrieval_metadata=metadata
)
assert retrieved_doc.text == sample_text
assert retrieved_doc.score == score
assert retrieved_doc.retrieval_metadata == metadata

View File

@@ -14,8 +14,8 @@ def regex_extractor():
def test_run_document(regex_extractor):
document = Document(text="This is a test. 1 2 3")
extracted_document = regex_extractor(document)
extracted_texts = [each.text for each in extracted_document]
assert extracted_texts == ["One", "Two", "Three"]
assert extracted_document.text == "One"
assert extracted_document.matches == ["One", "Two", "Three"]
def test_is_document(regex_extractor):
@@ -30,11 +30,13 @@ def test_is_batch(regex_extractor):
def test_run_raw(regex_extractor):
output = regex_extractor("This is a test. 123")
output = [each.text for each in output]
assert output == ["123"]
assert output.text == "123"
assert output.matches == ["123"]
def test_run_batch_raw(regex_extractor):
output = regex_extractor(["This is a test. 123", "456"])
output = [[each.text for each in batch] for batch in output]
assert output == [["123"], ["456"]]
extracted_text = [each.text for each in output]
extracted_matches = [each.matches for each in output]
assert extracted_text == ["123", "456"]
assert extracted_matches == [["123"], ["456"]]

View File

@@ -54,10 +54,7 @@ def test_run():
result = prompt()
assert (
result.text
== "str = Alice, int = 30, doc = Helloo, Alice!, comp = ['One', 'Two', 'Three']"
)
assert result.text == "str = Alice, int = 30, doc = Helloo, Alice!, comp = One"
def test_set_method():