[AUR-338, AUR-406, AUR-407] Export pipeline to config for PromptUI. Construct PromptUI dynamically based on config. (#16)

Flow: pipeline > config > UI. Provides an example project for promptui.

- Pipeline to config: `kotaemon.contribs.promptui.config.export_pipeline_to_config`. The config follows the schema specified in this document: https://cinnamon-ai.atlassian.net/wiki/spaces/ATM/pages/2748711193/Technical+Detail. Note: this implementation excludes the logs, which will be handled in AUR-408.
- Config to UI: `kotaemon.contribs.promptui.build_from_yaml`
- Example project is located at `examples/promptui/`
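
For reference, a minimal sketch of the full flow, combining the entry points above with the `build_from_dict` helper exercised in the tests (the pipeline class and YAML path here are hypothetical placeholders):

```python
from kotaemon.contribs.promptui.config import export_pipeline_to_config
from kotaemon.contribs.promptui.ui import build_from_dict

from my_project.pipelines import MyPipeline  # hypothetical pipeline class

# pipeline -> config; passing `path` also persists the config as YAML
config = export_pipeline_to_config(MyPipeline, path="promptui.yml")

# config -> UI; accepts the dict directly or the path to the YAML file
demo = build_from_dict(config)
demo.launch()
```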
Nguyen Trung Duc (john), 2023-09-21 14:27:23 +07:00, committed by GitHub
parent c329c4c03f
commit c6dd01e820
18 changed files with 503 additions and 46 deletions

Binary file not shown.

@@ -1 +1 @@
-credentials.txt:1e17fa46dd8353b5ded588b32983ac7d800e70fd16bc5831663b9aaefc409011
+credentials.txt:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5

@@ -25,7 +25,7 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-        args: ["--max-line-length", "88"]
+        args: ["--max-line-length", "88", "--extend-ignore", "E203"]
   - repo: https://github.com/myint/autoflake
     rev: v1.4
     hooks:
@@ -47,4 +47,5 @@ repos:
     rev: "v1.5.1"
     hooks:
       - id: mypy
+        additional_dependencies: [types-PyYAML==6.0.12.11]
         args: ["--check-untyped-defs", "--ignore-missing-imports"]

Binary file not shown.

@@ -0,0 +1,20 @@
import gradio as gr

COMPONENTS_CLASS = {
    "text": gr.components.Textbox,
    "checkbox": gr.components.CheckboxGroup,
    "dropdown": gr.components.Dropdown,
    "file": gr.components.File,
    "image": gr.components.Image,
    "number": gr.components.Number,
    "radio": gr.components.Radio,
    "slider": gr.components.Slider,
}
SUPPORTED_COMPONENTS = set(COMPONENTS_CLASS.keys())
DEFAULT_COMPONENT_BY_TYPES = {
    "str": "text",
    "bool": "checkbox",
    "int": "number",
    "float": "number",
    "list": "dropdown",
}
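
A config entry names a component, and the builder instantiates the matching Gradio class with the entry's params; a hypothetical entry for illustration (resolved the same way as `get_component` in the UI module below):

```python
component_def = {
    "component": "slider",
    "params": {"minimum": 0.0, "maximum": 1.0, "value": 0.7},
}
widget = COMPONENTS_CLASS[component_def["component"]](**component_def["params"])
```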

@@ -1 +1,132 @@
"""Get config from Pipeline"""
import inspect
from pathlib import Path
from typing import Any, Dict, Optional, Type, Union

import yaml

from ...base import BaseComponent
from .base import DEFAULT_COMPONENT_BY_TYPES


def config_from_value(value: Any) -> dict:
    """Get the config from default value

    Args:
        value (Any): default value

    Returns:
        dict: config
    """
    component = DEFAULT_COMPONENT_BY_TYPES.get(type(value).__name__, "text")
    return {
        "component": component,
        "params": {
            "value": value,
        },
    }


def handle_param(param: dict) -> dict:
    """Convert param definition into promptui-compliant config

    Supported gradio's UI components are (https://www.gradio.app/docs/components)
        - CheckBoxGroup: list (multi select)
        - DropDown: list (single select)
        - File
        - Image
        - Number: int / float
        - Radio: list (single select)
        - Slider: int / float
        - TextBox: str
    """
    params = {}
    default = param.get("default", None)
    if isinstance(default, str) and default.startswith("{{") and default.endswith("}}"):
        default = None
    if default is not None:
        params["value"] = default

    type_: str = type(default).__name__ if default is not None else ""
    ui_component = DEFAULT_COMPONENT_BY_TYPES.get(type_, "text")

    return {
        "component": ui_component,
        "params": params,
    }


def handle_node(node: dict) -> dict:
    """Convert node definition into promptui-compliant config"""
    config = {}
    for name, param_def in node.get("params", {}).items():
        if isinstance(param_def["default_callback"], str):
            continue
        config[name] = handle_param(param_def)
    for name, node_def in node.get("nodes", {}).items():
        if isinstance(node_def["default_callback"], str):
            continue
        for key, value in handle_node(node_def["default"]).items():
            config[f"{name}.{key}"] = value
        for key, value in node_def["default_kwargs"].items():
            config[f"{name}.{key}"] = config_from_value(value)
    return config


def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
    """Get the input from the pipeline"""
    if not hasattr(pipeline, "run_raw"):
        return {}

    signature = inspect.signature(pipeline.run_raw)
    inputs: Dict[str, Dict] = {}
    for name, param in signature.parameters.items():
        if name in ["self", "args", "kwargs"]:
            continue
        input_def: Dict[str, Optional[Any]] = {"component": "text"}
        default = param.default
        if default is param.empty:
            inputs[name] = input_def
            continue

        params = {}
        params["value"] = default
        type_ = type(default).__name__ if default is not None else None
        ui_component = None
        if type_ is not None:
            ui_component = "text"

        input_def["component"] = ui_component
        input_def["params"] = params
        inputs[name] = input_def

    return inputs


def export_pipeline_to_config(
    pipeline: Union[BaseComponent, Type[BaseComponent]],
    path: Optional[str] = None,
) -> dict:
    """Export a pipeline to a promptui-compliant config dict"""
    if inspect.isclass(pipeline):
        pipeline = pipeline()

    pipeline_def = pipeline.describe()
    config = {
        f"{pipeline.__module__}.{pipeline.__class__.__name__}": {
            "params": handle_node(pipeline_def),
            "inputs": handle_input(pipeline),
            "outputs": [{"step": ".", "component": "text"}],
        }
    }
    if path is not None:
        old_config = config
        if Path(path).is_file():
            with open(path) as f:
                old_config = yaml.safe_load(f)
            old_config.update(config)
        with open(path, "w") as f:
            yaml.safe_dump(old_config, f)

    return config
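
Given the shapes above, the returned dict looks roughly like this (keys and values illustrative, not from a real run):

```python
{
    "my_project.pipelines.MyPipeline": {
        "params": {
            "llm.temperature": {"component": "number", "params": {"value": 0}},
        },
        "inputs": {
            "text": {"component": "text"},
        },
        "outputs": [{"step": ".", "component": "text"}],
    }
}
```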

@@ -1,6 +1,151 @@
from typing import Union

import gradio as gr
import yaml
from theflow.utils.modules import import_dotted_string

from kotaemon.contribs.promptui.base import COMPONENTS_CLASS, SUPPORTED_COMPONENTS

USAGE_INSTRUCTION = """In case of errors, you can:
- Create bug fix and make PR at: https://github.com/Cinnamon/kotaemon
- Ping any of @john @tadashi @ian @jacky in Slack channel #llm-productization"""


def get_component(component_def: dict) -> gr.components.Component:
    """Get the component based on component definition"""
    component_cls = None

    if "component" in component_def:
        component = component_def["component"]
        if component not in SUPPORTED_COMPONENTS:
            raise ValueError(
                f"Unsupported UI component: {component}. "
                f"Must be one of {SUPPORTED_COMPONENTS}"
            )

        component_cls = COMPONENTS_CLASS[component]
    else:
        raise ValueError(
            f"Cannot decide the component from {component_def}. "
            "Please specify `component` with 1 of the following "
            f"values: {SUPPORTED_COMPONENTS}"
        )

    return component_cls(**component_def.get("params", {}))


def construct_ui(config, func_run, func_export) -> gr.Blocks:
    """Create UI from config file. Execute the UI from config file

    - Can do now: Log from stdout to UI
    - In the future, we can provide some hooks and callbacks to let developers better
      fine-tune the UI behavior.
    """
    inputs, outputs, params = [], [], []

    for name, component_def in config.get("inputs", {}).items():
        if "params" not in component_def:
            component_def["params"] = {}
        component_def["params"]["interactive"] = True
        component = get_component(component_def)
        if hasattr(component, "label") and not component.label:  # type: ignore
            component.label = name  # type: ignore
        inputs.append(component)

    for name, component_def in config.get("params", {}).items():
        if "params" not in component_def:
            component_def["params"] = {}
        component_def["params"]["interactive"] = True
        component = get_component(component_def)
        if hasattr(component, "label") and not component.label:  # type: ignore
            component.label = name  # type: ignore
        params.append(component)

    for idx, component_def in enumerate(config.get("outputs", [])):
        if "params" not in component_def:
            component_def["params"] = {}
        component_def["params"]["interactive"] = False
        component = get_component(component_def)
        if hasattr(component, "label") and not component.label:  # type: ignore
            component.label = f"Output {idx}"
        outputs.append(component)

    temp = gr.Tab
    with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
        with gr.Accordion(label="Usage", open=False):
            gr.Markdown(USAGE_INSTRUCTION)
        with gr.Row():
            run_btn = gr.Button("Run")
            run_btn.click(func_run, inputs=inputs + params, outputs=outputs)
            export_btn = gr.Button("Export")
            export_btn.click(func_export, inputs=None, outputs=None)
        with gr.Row():
            with gr.Column():
                with temp("Inputs"):
                    for component in inputs:
                        component.render()
                with temp("Params"):
                    for component in params:
                        component.render()
            with gr.Column():
                for component in outputs:
                    component.render()

    return demo


def build_pipeline_ui(config: dict, pipeline_def):
    """Build a tab from config file"""
    inputs_name = list(config.get("inputs", {}).keys())
    params_name = list(config.get("params", {}).keys())
    outputs_def = config.get("outputs", [])

    def run_func(*args):
        inputs = {
            name: value for name, value in zip(inputs_name, args[: len(inputs_name)])
        }
        params = {
            name: value for name, value in zip(params_name, args[len(inputs_name) :])
        }
        pipeline = pipeline_def()
        pipeline.set(params)
        pipeline(**inputs)
        if outputs_def:
            outputs = []
            for output_def in outputs_def:
                output = pipeline.last_run.logs(output_def["step"])
                if "item" in output_def:
                    output = output[output_def["item"]]
                outputs.append(output)
            return outputs

    # TODO: export_func is None for now
    return construct_ui(config, run_func, None)


def build_from_dict(config: Union[str, dict]):
    """Build a full UI from YAML config file"""
    if isinstance(config, str):
        with open(config) as f:
            config_dict: dict = yaml.safe_load(f)
    elif isinstance(config, dict):
        config_dict = config
    else:
        raise ValueError(
            f"config must be either a yaml path or a dict, got {type(config)}"
        )

    demos = []
    for key, value in config_dict.items():
        pipeline_def = import_dotted_string(key, safe=False)
        demos.append(build_pipeline_ui(value, pipeline_def))
    if len(demos) == 1:
        demo = demos[0]
    else:
        demo = gr.TabbedInterface(demos, list(config_dict.keys()))

    return demo
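
Usage sketch: with a single pipeline in the config, `build_from_dict` returns the `gr.Blocks` demo directly; with several, it wraps them in a `gr.TabbedInterface` (the YAML path is a hypothetical file written by `export_pipeline_to_config`):

```python
demo = build_from_dict("promptui.yml")  # or pass the config dict directly
demo.launch()
```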

@@ -1,4 +1,4 @@
 from .base import BaseDocumentStore
-from .simple import InMemoryDocumentStore
+from .in_memory import InMemoryDocumentStore

 __all__ = ["BaseDocumentStore", "InMemoryDocumentStore"]

@@ -10,7 +10,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
     """Simple memory document store that store document in a dictionary"""

     def __init__(self):
-        self.store = {}
+        self._store = {}

     def add(
         self,
@@ -32,20 +32,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
             docs = [docs]

         for doc_id, doc in zip(doc_ids, docs):
-            if doc_id in self.store and not exist_ok:
+            if doc_id in self._store and not exist_ok:
                 raise ValueError(f"Document with id {doc_id} already exist")
-            self.store[doc_id] = doc
+            self._store[doc_id] = doc

     def get(self, ids: Union[List[str], str]) -> List[Document]:
         """Get document by id"""
         if not isinstance(ids, list):
             ids = [ids]
-        return [self.store[doc_id] for doc_id in ids]
+        return [self._store[doc_id] for doc_id in ids]

     def get_all(self) -> dict:
         """Get all documents"""
-        return self.store
+        return self._store

     def delete(self, ids: Union[List[str], str]):
         """Delete document by id"""
@@ -53,11 +53,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
             ids = [ids]
         for doc_id in ids:
-            del self.store[doc_id]
+            del self._store[doc_id]

     def save(self, path: Union[str, Path]):
         """Save document to path"""
-        store = {key: value.to_dict() for key, value in self.store.items()}
+        store = {key: value.to_dict() for key, value in self._store.items()}
         with open(path, "w") as f:
             json.dump(store, f)
@@ -65,4 +65,4 @@ class InMemoryDocumentStore(BaseDocumentStore):
         """Load document store from path"""
         with open(path) as f:
             store = json.load(f)
-        self.store = {key: Document.from_dict(value) for key, value in store.items()}
+        self._store = {key: Document.from_dict(value) for key, value in store.items()}

@@ -1,4 +1,5 @@
 from haystack.schema import Document as HaystackDocument
+from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument

 SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -20,3 +21,17 @@ class Document(BaseDocument):
         metadata = self.metadata or {}
         text = self.text
         return HaystackDocument(content=text, meta=metadata)
+
+
+class RetrievedDocument(Document):
+    """Subclass of Document with retrieval-related information
+
+    Attributes:
+        score (float): score of the document (from 0.0 to 1.0)
+        retrieval_metadata (dict): metadata from the retrieval process, can be used
+            by different components in a retrieval pipeline to communicate with each
+            other
+    """
+
+    score: float = Field(default=0.0)
+    retrieval_metadata: dict = Field(default={})
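
A small illustration of the new type (constructor kwargs come through the llama_index `Document` base; values made up):

```python
doc = RetrievedDocument(text="A sample Document from kotaemon", score=0.87)
assert 0.0 <= doc.score <= 1.0 and doc.retrieval_metadata == {}
```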

@@ -27,7 +27,7 @@ class LangchainLLM(LLM):
             self._kwargs[param] = params.pop(param)
         super().__init__(**params)

-    @Param.decorate()
+    @Param.decorate(no_cache=True)
     def agent(self):
         return self._lc_class(**self._kwargs)

@@ -1,8 +1,10 @@
-from typing import List
+import uuid
+from typing import List, Optional

 from theflow import Node, Param

 from ..base import BaseComponent
+from ..docstores import BaseDocumentStore
 from ..documents.base import Document
 from ..embeddings import BaseEmbeddings
 from ..vectorstores import BaseVectorStore
@@ -18,21 +20,30 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
     """

     vector_store: Param[BaseVectorStore] = Param()
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: Node[BaseEmbeddings] = Node()

-    # TODO: populate to document store as well when it's finished
     # TODO: refer to llama_index's storage as well

     def run_raw(self, text: str) -> None:
-        self.vector_store.add([self.embedding(text)])
+        document = Document(text=text, id_=str(uuid.uuid4()))
+        self.run_batch_document([document])

     def run_batch_raw(self, text: List[str]) -> None:
-        self.vector_store.add(self.embedding(text))
+        documents = [Document(t, id_=str(uuid.uuid4())) for t in text]
+        self.run_batch_document(documents)

     def run_document(self, text: Document) -> None:
-        self.vector_store.add([self.embedding(text)])
+        self.run_batch_document([text])

     def run_batch_document(self, text: List[Document]) -> None:
-        self.vector_store.add(self.embedding(text))
+        embeddings = self.embedding(text)
+        self.vector_store.add(
+            embeddings=embeddings,
+            ids=[t.id_ for t in text],
+        )
+        if self.doc_store:
+            self.doc_store.add(text)

     def is_document(self, text) -> bool:
         if isinstance(text, Document):

@@ -1,47 +1,87 @@
-from typing import List
+from abc import abstractmethod
+from typing import List, Optional

 from theflow import Node, Param

 from ..base import BaseComponent
-from ..documents.base import Document
+from ..docstores import BaseDocumentStore
+from ..documents.base import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..vectorstores import BaseVectorStore


-class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
+class BaseRetrieval(BaseComponent):
+    """Define the base interface of a retrieval pipeline"""
+
+    @abstractmethod
+    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
+        ...
+
+    @abstractmethod
+    def run_batch_raw(
+        self, text: List[str], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        ...
+
+    @abstractmethod
+    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
+        ...
+
+    @abstractmethod
+    def run_batch_document(
+        self, text: List[Document], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        ...
+
+
+class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
     """Retrieve list of documents from vector store"""

     vector_store: Param[BaseVectorStore] = Param()
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: Node[BaseEmbeddings] = Node()

-    # TODO: populate to document store as well when it's finished
     # TODO: refer to llama_index's storage as well

-    def run_raw(self, text: str) -> List[str]:
-        emb = self.embedding(text)
-        return self.vector_store.query(embedding=emb)[2]
+    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
+        return self.run_batch_raw([text], top_k=top_k)[0]

-    def run_batch_raw(self, text: List[str]) -> List[List[str]]:
+    def run_batch_raw(
+        self, text: List[str], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        if self.doc_store is None:
+            raise ValueError(
+                "doc_store is not provided. Please provide a doc_store to "
+                "retrieve the documents"
+            )
+
         result = []
         for each_text in text:
             emb = self.embedding(each_text)
-            result.append(self.vector_store.query(embedding=emb)[2])
+            _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
+            docs = self.doc_store.get(ids)
+            each_result = [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(docs, scores)
+            ]
+            result.append(each_result)
         return result

-    def run_document(self, text: Document) -> List[str]:
-        return self.run_raw(text.text)
+    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
+        return self.run_raw(text.text, top_k)

-    def run_batch_document(self, text: List[Document]) -> List[List[str]]:
-        input_text = [each.text for each in text]
-        return self.run_batch_raw(input_text)
+    def run_batch_document(
+        self, text: List[Document], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        return self.run_batch_raw(text=[t.text for t in text], top_k=top_k)

-    def is_document(self, text) -> bool:
+    def is_document(self, text, *args, **kwargs) -> bool:
         if isinstance(text, Document):
             return True
         elif isinstance(text, List) and isinstance(text[0], Document):
             return True
         return False

-    def is_batch(self, text) -> bool:
+    def is_batch(self, text, *args, **kwargs) -> bool:
         if isinstance(text, list):
             return True
         return False
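
A wiring sketch for the new `doc_store` requirement, mirroring the updated tests below (`db`, `doc_store`, and `embedding` assumed to be set up as in those tests):

```python
retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
    vector_store=db, doc_store=doc_store, embedding=embedding
)
# a single-text call dispatches to run_raw and returns List[RetrievedDocument];
# without a doc_store, run_batch_raw raises a ValueError
hits = retrieval_pipeline(text="Hello world")
```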

@@ -144,8 +144,8 @@ class LlamaIndexVectorStore(BaseVectorStore):
                 query_embedding=embedding,
                 similarity_top_k=top_k,
                 node_ids=ids,
+            ),
             **kwargs,
-            ),
         )
embeddings = [] embeddings = []

@@ -3,7 +3,7 @@ minversion = 7.4.0
 testpaths = tests
 addopts = -ra -q
 log_cli=true
-log_level=DEBUG
+log_level=WARNING
 log_format = %(asctime)s %(levelname)s %(message)s
 log_date_format = %Y-%m-%d %H:%M:%S
 log_file = logs/pytest-logs.txt

@@ -34,6 +34,7 @@ setuptools.setup(
         "llama-index",
         "llama-hub",
         "nltk",
+        "gradio",
     ],
     extras_require={
         "dev": [

@@ -1,9 +1,11 @@
 import json
 from pathlib import Path
+from typing import cast

 import pytest
 from openai.api_resources.embedding import Embedding

+from kotaemon.docstores import InMemoryDocumentStore
 from kotaemon.documents.base import Document
 from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
@@ -21,6 +23,7 @@ def mock_openai_embedding(monkeypatch):

 def test_indexing(mock_openai_embedding, tmp_path):
     db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
     embedding = AzureOpenAIEmbeddings(
         model="text-embedding-ada-002",
         deployment="embedding-deployment",
@@ -29,15 +32,19 @@ def test_indexing(mock_openai_embedding, tmp_path):
     )

     pipeline = IndexVectorStoreFromDocumentPipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, embedding=embedding, doc_store=doc_store
     )
+    pipeline.doc_store = cast(InMemoryDocumentStore, pipeline.doc_store)
     assert pipeline.vector_store._collection.count() == 0, "Expected empty collection"
+    assert len(pipeline.doc_store._store) == 0, "Expected empty doc store"
     pipeline(text=Document(text="Hello world"))
     assert pipeline.vector_store._collection.count() == 1, "Index 1 item"
+    assert len(pipeline.doc_store._store) == 1, "Expected 1 document"


 def test_retrieving(mock_openai_embedding, tmp_path):
     db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
     embedding = AzureOpenAIEmbeddings(
         model="text-embedding-ada-002",
         deployment="embedding-deployment",
@@ -46,14 +53,14 @@ def test_retrieving(mock_openai_embedding, tmp_path):
     )

     index_pipeline = IndexVectorStoreFromDocumentPipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, embedding=embedding, doc_store=doc_store
     )
     retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, doc_store=doc_store, embedding=embedding
     )

     index_pipeline(text=Document(text="Hello world"))
     output = retrieval_pipeline(text=["Hello world", "Hello world"])
-    assert len(output) == 2, "Expected 2 results"
-    assert output[0] == output[1], "Expected identical results"
+    assert len(output) == 2, "Expect 2 results"
+    assert output[0] == output[1], "Expect identical results"

tests/test_promptui.py (new file, 86 lines added)

@@ -0,0 +1,86 @@
import pytest

from kotaemon.contribs.promptui.config import export_pipeline_to_config
from kotaemon.contribs.promptui.ui import build_from_dict


@pytest.fixture()
def simple_pipeline_cls(tmp_path):
    """Create a pipeline class that can be used"""
    from typing import List

    from theflow import Node

    from kotaemon.base import BaseComponent
    from kotaemon.embeddings import AzureOpenAIEmbeddings
    from kotaemon.llms.completions.openai import AzureOpenAI
    from kotaemon.pipelines.retrieving import (
        RetrieveDocumentFromVectorStorePipeline,
    )
    from kotaemon.vectorstores import ChromaVectorStore

    class Pipeline(BaseComponent):
        vectorstore_path: str = str(tmp_path)
        llm: Node[AzureOpenAI] = Node(
            default=AzureOpenAI,
            default_kwargs={
                "openai_api_base": "https://test.openai.azure.com/",
                "openai_api_key": "some-key",
                "openai_api_version": "2023-03-15-preview",
                "deployment_name": "gpt35turbo",
                "temperature": 0,
                "request_timeout": 60,
            },
        )

        @Node.decorate(depends_on=["vectorstore_path"])
        def retrieving_pipeline(self):
            vector_store = ChromaVectorStore(self.vectorstore_path)
            embedding = AzureOpenAIEmbeddings(
                model="text-embedding-ada-002",
                deployment="embedding-deployment",
                openai_api_base="https://test.openai.azure.com/",
                openai_api_key="some-key",
            )
            return RetrieveDocumentFromVectorStorePipeline(
                vector_store=vector_store, embedding=embedding
            )

        def run_raw(self, text: str) -> str:
            matched_texts: List[str] = self.retrieving_pipeline(text)
            return self.llm("\n".join(matched_texts)).text[0]

    return Pipeline


Pipeline = simple_pipeline_cls


class TestPromptConfig:
    def test_export_prompt_config(self, simple_pipeline_cls):
        """Test if the prompt config is exported correctly"""
        pipeline = simple_pipeline_cls()
        config_dict = export_pipeline_to_config(pipeline)
        config = list(config_dict.values())[0]

        assert "inputs" in config, "inputs should be in config"
        assert "text" in config["inputs"], "inputs should have config"

        assert "params" in config, "params should be in config"
        assert "vectorstore_path" in config["params"]
        assert "llm.deployment_name" in config["params"]
        assert "llm.openai_api_base" in config["params"]
        assert "llm.openai_api_key" in config["params"]
        assert "llm.openai_api_version" in config["params"]
        assert "llm.request_timeout" in config["params"]
        assert "llm.temperature" in config["params"]


class TestPromptUI:
    def test_uigeneration(self, simple_pipeline_cls):
        """Test if the gradio UI is exposed without any problem"""
        pipeline = simple_pipeline_cls()
        config = export_pipeline_to_config(pipeline)

        build_from_dict(config)