* Adopt Example1 disclaimer pipeline * Update Document class * Add composite components * Modify Extractor behaviours
142 lines
3.9 KiB
Python
142 lines
3.9 KiB
Python
import pytest
|
|
|
|
from kotaemon.composite import (
|
|
GatedBranchingPipeline,
|
|
GatedLinearPipeline,
|
|
SimpleBranchingPipeline,
|
|
SimpleLinearPipeline,
|
|
)
|
|
from kotaemon.llms.chats.openai import AzureChatOpenAI
|
|
from kotaemon.post_processing.extractor import RegexExtractor
|
|
from kotaemon.prompt.base import BasePromptComponent
|
|
|
|
|
|
@pytest.fixture
def mock_llm():
    """Provide an ``AzureChatOpenAI`` instance configured with placeholder values.

    The endpoint, key and API version are literal placeholder strings, not real
    credentials. Every test in this module that drives a pipeline patches
    ``AzureChatOpenAI.run`` before running it.
    """
    placeholder_config = {
        "openai_api_base": "OPENAI_API_BASE",
        "openai_api_key": "OPENAI_API_KEY",
        "openai_api_version": "OPENAI_API_VERSION",
        "deployment_name": "dummy-q2-gpt35",
        "temperature": 0,
        "request_timeout": 600,
    }
    return AzureChatOpenAI(**placeholder_config)
|
|
|
|
|
|
@pytest.fixture
def mock_post_processor():
    """Provide an extractor that keeps runs of digits from LLM output.

    The pipeline tests in this module expect, e.g., ``"This is a test 123"`` to
    be reduced to ``"123"`` by this post-processor.
    """
    digit_run_pattern = r"\d+"
    return RegexExtractor(pattern=digit_run_pattern)
|
|
|
|
|
|
@pytest.fixture
def mock_prompt():
    """Provide a prompt component whose template has a single ``{value}`` slot."""
    template_text = "Test prompt {value}"
    prompt = BasePromptComponent(template=template_text)
    return prompt
|
|
|
|
|
|
@pytest.fixture
def mock_simple_linear_pipeline(mock_prompt, mock_llm, mock_post_processor):
    """Chain the prompt, LLM and digit-extractor fixtures into a linear pipeline."""
    pipeline = SimpleLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
    )
    return pipeline
|
|
|
|
|
|
@pytest.fixture
def mock_gated_linear_pipeline_positive(mock_prompt, mock_llm, mock_post_processor):
    """Provide a gated linear pipeline whose gate looks for the word ``positive``.

    The gate is a ``RegexExtractor`` with the literal pattern ``"positive"``;
    presumably it is applied to the ``condition_text`` passed to ``run`` —
    confirm against ``GatedLinearPipeline``.
    """
    gate = RegexExtractor(pattern="positive")
    return GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=gate,
    )
|
|
|
|
|
|
@pytest.fixture
def mock_gated_linear_pipeline_negative(mock_prompt, mock_llm, mock_post_processor):
    """Provide a gated linear pipeline whose gate looks for the word ``negative``.

    The gate is a ``RegexExtractor`` with the literal pattern ``"negative"``;
    presumably it is applied to the ``condition_text`` passed to ``run`` —
    confirm against ``GatedLinearPipeline``.
    """
    gate = RegexExtractor(pattern="negative")
    return GatedLinearPipeline(
        prompt=mock_prompt,
        llm=mock_llm,
        post_processor=mock_post_processor,
        condition=gate,
    )
|
|
|
|
|
|
def test_simple_linear_pipeline_run(mocker, mock_simple_linear_pipeline):
    """Run prompt -> LLM -> digit extractor and check only the digits survive."""
    openai_mocker = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="This is a test 123"
    )

    result = mock_simple_linear_pipeline.run(value="abc")

    assert result.text == "123"
    # Unlike a bare ``call_count`` comparison, assert_called_once() lists the
    # recorded calls in its failure message.
    openai_mocker.assert_called_once()
|
|
|
|
|
|
def test_gated_linear_pipeline_run_positive(
    mocker, mock_gated_linear_pipeline_positive
):
    """A "positive"-gated pipeline calls the LLM when the condition text matches.

    ``"positive condition"`` contains the gate's pattern, so the test expects
    exactly one LLM call and the digits extracted from its reply.
    """
    openai_mocker = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="This is a test 123."
    )

    result = mock_gated_linear_pipeline_positive.run(
        value="abc", condition_text="positive condition"
    )

    assert result.text == "123"
    # assert_called_once() reports the recorded calls if the count is wrong.
    openai_mocker.assert_called_once()
|
|
|
|
|
|
def test_gated_linear_pipeline_run_negative(
    mocker, mock_gated_linear_pipeline_positive
):
    """A "positive"-gated pipeline must not reach the LLM for a non-matching condition.

    This deliberately uses the *positive*-gated fixture: the condition text
    ``"negative condition"`` does not contain ``"positive"``. The test therefore
    expects the gate to block execution, a result whose ``content`` is ``None``,
    and no LLM call at all.
    """
    openai_mocker = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="This is a test 123."
    )

    result = mock_gated_linear_pipeline_positive.run(
        value="abc", condition_text="negative condition"
    )

    assert result.content is None
    # assert_not_called() lists any unexpected calls in its failure message.
    openai_mocker.assert_not_called()
|
|
|
|
|
|
def test_simple_branching_pipeline_run(mocker, mock_simple_linear_pipeline):
    """Every branch of a SimpleBranchingPipeline runs and yields its own result.

    The patched LLM returns a different reply on each call. The digit extractor
    maps the reply that has no digits to an empty string.
    """
    llm_replies = [
        "This is a test 123.",
        "a quick brown fox",
        "jumps over the lazy dog 456",
    ]
    patched_run = mocker.patch.object(
        AzureChatOpenAI, "run", side_effect=llm_replies
    )

    pipeline = SimpleBranchingPipeline()
    branch_count = 3
    for _ in range(branch_count):
        pipeline.add_branch(mock_simple_linear_pipeline)

    outputs = pipeline.run(value="abc")
    extracted = [doc.text for doc in outputs]

    assert len(outputs) == 3
    assert extracted == ["123", "", "456"]
    assert patched_run.call_count == 3
|
|
|
|
|
|
def test_simple_gated_branching_pipeline_run(
    mocker, mock_gated_linear_pipeline_positive, mock_gated_linear_pipeline_negative
):
    """Exercise GatedBranchingPipeline with one negative- and two positive-gated branches.

    NOTE(review): the two expected LLM calls presumably arise as follows. The
    negative branch is gated out. The first positive branch runs but extracts
    nothing, because the reply contains no digits. That empty result lets the
    pipeline fall through to the second positive branch. Confirm against
    ``GatedBranchingPipeline.run``.
    """
    patched_run = mocker.patch.object(
        AzureChatOpenAI, "run", return_value="a quick brown fox"
    )
    pipeline = GatedBranchingPipeline()

    branches = (
        mock_gated_linear_pipeline_negative,
        mock_gated_linear_pipeline_positive,
        mock_gated_linear_pipeline_positive,
    )
    for branch in branches:
        pipeline.add_branch(branch)

    outcome = pipeline.run(value="abc", condition_text="positive condition")

    assert outcome.text == ""
    assert patched_run.call_count == 2
|