[AUR-338, AUR-406, AUR-407] Export pipeline to config for PromptUI. Construct PromptUI dynamically based on config. (#16)
Flow: pipeline → config → UI, with an example project for promptui.

- Pipeline to config: `kotaemon.contribs.promptui.config.export_pipeline_to_config`. The config follows the schema specified in this document: https://cinnamon-ai.atlassian.net/wiki/spaces/ATM/pages/2748711193/Technical+Detail. Note: this implementation excludes the logs, which will be handled in AUR-408.
- Config to UI: `kotaemon.contribs.promptui.build_from_yaml`
- The example project is located at `examples/promptui/`
parent c329c4c03f
commit c6dd01e820
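In short, the intended usage is a two-step flow. A minimal sketch of the first step, assuming a `BaseComponent` subclass named `MyPipeline` (the class name and YAML path here are illustrative; the UI side is sketched after the `ui.py` diff below):

from kotaemon.contribs.promptui.config import export_pipeline_to_config

from my_project.pipelines import MyPipeline  # hypothetical pipeline class

# pipeline -> config: returns the config dict and, when `path` is given,
# merges it into the YAML file at that path
config = export_pipeline_to_config(MyPipeline, path="promptui.yml")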
Binary file not shown.
@@ -1 +1 @@
-credentials.txt:1e17fa46dd8353b5ded588b32983ac7d800e70fd16bc5831663b9aaefc409011
+credentials.txt:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5
@@ -25,7 +25,7 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-        args: ["--max-line-length", "88"]
+        args: ["--max-line-length", "88", "--extend-ignore", "E203"]
   - repo: https://github.com/myint/autoflake
     rev: v1.4
     hooks:
@@ -47,4 +47,5 @@ repos:
     rev: "v1.5.1"
     hooks:
       - id: mypy
+        additional_dependencies: [types-PyYAML==6.0.12.11]
         args: ["--check-untyped-defs", "--ignore-missing-imports"]
Binary file not shown.
knowledgehub/contribs/promptui/base.py (new file)
@@ -0,0 +1,20 @@
+import gradio as gr
+
+COMPONENTS_CLASS = {
+    "text": gr.components.Textbox,
+    "checkbox": gr.components.CheckboxGroup,
+    "dropdown": gr.components.Dropdown,
+    "file": gr.components.File,
+    "image": gr.components.Image,
+    "number": gr.components.Number,
+    "radio": gr.components.Radio,
+    "slider": gr.components.Slider,
+}
+SUPPORTED_COMPONENTS = set(COMPONENTS_CLASS.keys())
+DEFAULT_COMPONENT_BY_TYPES = {
+    "str": "text",
+    "bool": "checkbox",
+    "int": "number",
+    "float": "number",
+    "list": "dropdown",
+}
@@ -1 +1,132 @@
 """Get config from Pipeline"""
+import inspect
+from pathlib import Path
+from typing import Any, Dict, Optional, Type, Union
+
+import yaml
+
+from ...base import BaseComponent
+from .base import DEFAULT_COMPONENT_BY_TYPES
+
+
+def config_from_value(value: Any) -> dict:
+    """Get the config from default value
+
+    Args:
+        value (Any): default value
+
+    Returns:
+        dict: config
+    """
+    component = DEFAULT_COMPONENT_BY_TYPES.get(type(value).__name__, "text")
+    return {
+        "component": component,
+        "params": {
+            "value": value,
+        },
+    }
+
+
+def handle_param(param: dict) -> dict:
+    """Convert param definition into promptui-compliant config
+
+    Supported gradio's UI components are (https://www.gradio.app/docs/components)
+        - CheckBoxGroup: list (multi select)
+        - DropDown: list (single select)
+        - File
+        - Image
+        - Number: int / float
+        - Radio: list (single select)
+        - Slider: int / float
+        - TextBox: str
+    """
+    params = {}
+    default = param.get("default", None)
+    if isinstance(default, str) and default.startswith("{{") and default.endswith("}}"):
+        default = None
+    if default is not None:
+        params["value"] = default
+
+    type_: str = type(default).__name__ if default is not None else ""
+    ui_component = DEFAULT_COMPONENT_BY_TYPES.get(type_, "text")
+
+    return {
+        "component": ui_component,
+        "params": params,
+    }
+
+
+def handle_node(node: dict) -> dict:
+    """Convert node definition into promptui-compliant config"""
+    config = {}
+    for name, param_def in node.get("params", {}).items():
+        if isinstance(param_def["default_callback"], str):
+            continue
+        config[name] = handle_param(param_def)
+    for name, node_def in node.get("nodes", {}).items():
+        if isinstance(node_def["default_callback"], str):
+            continue
+        for key, value in handle_node(node_def["default"]).items():
+            config[f"{name}.{key}"] = value
+        for key, value in node_def["default_kwargs"].items():
+            config[f"{name}.{key}"] = config_from_value(value)
+
+    return config
+
+
+def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
+    """Get the input from the pipeline"""
+    if not hasattr(pipeline, "run_raw"):
+        return {}
+    signature = inspect.signature(pipeline.run_raw)
+    inputs: Dict[str, Dict] = {}
+    for name, param in signature.parameters.items():
+        if name in ["self", "args", "kwargs"]:
+            continue
+        input_def: Dict[str, Optional[Any]] = {"component": "text"}
+        default = param.default
+        if default is param.empty:
+            inputs[name] = input_def
+            continue
+
+        params = {}
+        params["value"] = default
+        type_ = type(default).__name__ if default is not None else None
+        ui_component = None
+        if type_ is not None:
+            ui_component = "text"
+
+        input_def["component"] = ui_component
+        input_def["params"] = params
+
+        inputs[name] = input_def
+
+    return inputs
+
+
+def export_pipeline_to_config(
+    pipeline: Union[BaseComponent, Type[BaseComponent]],
+    path: Optional[str] = None,
+) -> dict:
+    """Export a pipeline to a promptui-compliant config dict"""
+    if inspect.isclass(pipeline):
+        pipeline = pipeline()
+
+    pipeline_def = pipeline.describe()
+    config = {
+        f"{pipeline.__module__}.{pipeline.__class__.__name__}": {
+            "params": handle_node(pipeline_def),
+            "inputs": handle_input(pipeline),
+            "outputs": [{"step": ".", "component": "text"}],
+        }
+    }
+    if path is not None:
+        old_config = config
+        if Path(path).is_file():
+            with open(path) as f:
+                old_config = yaml.safe_load(f)
+            old_config.update(config)
+        with open(path, "w") as f:
+            yaml.safe_dump(old_config, f)
+
+    return config
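For reference, the dict returned by `export_pipeline_to_config` above should look roughly like the following sketch. The dotted module path and the parameter values are illustrative; the keys follow `handle_node` (params, with nested nodes flattened to `node.param`), `handle_input` (one entry per `run_raw` argument), and the fixed `outputs` entry:

config = {
    "my_project.pipelines.MyPipeline": {  # hypothetical dotted path to the pipeline
        "params": {
            "llm.temperature": {"component": "number", "params": {"value": 0}},
        },
        "inputs": {
            "text": {"component": "text"},
        },
        "outputs": [{"step": ".", "component": "text"}],
    }
}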
@@ -1,6 +1,151 @@
+from typing import Union
+
+import gradio as gr
+import yaml
+from theflow.utils.modules import import_dotted_string
+
+from kotaemon.contribs.promptui.base import COMPONENTS_CLASS, SUPPORTED_COMPONENTS
+
+USAGE_INSTRUCTION = """In case of errors, you can:
+
+- Create bug fix and make PR at: https://github.com/Cinnamon/kotaemon
+- Ping any of @john @tadashi @ian @jacky in Slack channel #llm-productization"""
+
+
+def get_component(component_def: dict) -> gr.components.Component:
+    """Get the component based on component definition"""
+    component_cls = None
+
+    if "component" in component_def:
+        component = component_def["component"]
+        if component not in SUPPORTED_COMPONENTS:
+            raise ValueError(
+                f"Unsupported UI component: {component}. "
+                f"Must be one of {SUPPORTED_COMPONENTS}"
+            )
+
+        component_cls = COMPONENTS_CLASS[component]
+    else:
+        raise ValueError(
+            f"Cannot decide the component from {component_def}. "
+            "Please specify `component` with 1 of the following "
+            f"values: {SUPPORTED_COMPONENTS}"
+        )
+
+    return component_cls(**component_def.get("params", {}))
+
+
+def construct_ui(config, func_run, func_export) -> gr.Blocks:
     """Create UI from config file. Execute the UI from config file
 
     - Can do now: Log from stdout to UI
     - In the future, we can provide some hooks and callbacks to let developers better
     fine-tune the UI behavior.
     """
+    inputs, outputs, params = [], [], []
+    for name, component_def in config.get("inputs", {}).items():
+        if "params" not in component_def:
+            component_def["params"] = {}
+        component_def["params"]["interactive"] = True
+        component = get_component(component_def)
+        if hasattr(component, "label") and not component.label:  # type: ignore
+            component.label = name  # type: ignore
+
+        inputs.append(component)
+
+    for name, component_def in config.get("params", {}).items():
+        if "params" not in component_def:
+            component_def["params"] = {}
+        component_def["params"]["interactive"] = True
+        component = get_component(component_def)
+        if hasattr(component, "label") and not component.label:  # type: ignore
+            component.label = name  # type: ignore
+
+        params.append(component)
+
+    for idx, component_def in enumerate(config.get("outputs", [])):
+        if "params" not in component_def:
+            component_def["params"] = {}
+        component_def["params"]["interactive"] = False
+        component = get_component(component_def)
+        if hasattr(component, "label") and not component.label:  # type: ignore
+            component.label = f"Output {idx}"
+
+        outputs.append(component)
+
+    temp = gr.Tab
+    with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
+        with gr.Accordion(label="Usage", open=False):
+            gr.Markdown(USAGE_INSTRUCTION)
+        with gr.Row():
+            run_btn = gr.Button("Run")
+            run_btn.click(func_run, inputs=inputs + params, outputs=outputs)
+            export_btn = gr.Button("Export")
+            export_btn.click(func_export, inputs=None, outputs=None)
+        with gr.Row():
+            with gr.Column():
+                with temp("Inputs"):
+                    for component in inputs:
+                        component.render()
+                with temp("Params"):
+                    for component in params:
+                        component.render()
+            with gr.Column():
+                for component in outputs:
+                    component.render()
+
+    return demo
+
+
+def build_pipeline_ui(config: dict, pipeline_def):
+    """Build a tab from config file"""
+    inputs_name = list(config.get("inputs", {}).keys())
+    params_name = list(config.get("params", {}).keys())
+    outputs_def = config.get("outputs", [])
+
+    def run_func(*args):
+        inputs = {
+            name: value for name, value in zip(inputs_name, args[: len(inputs_name)])
+        }
+        params = {
+            name: value for name, value in zip(params_name, args[len(inputs_name) :])
+        }
+        pipeline = pipeline_def()
+        pipeline.set(params)
+        pipeline(**inputs)
+        if outputs_def:
+            outputs = []
+            for output_def in outputs_def:
+                output = pipeline.last_run.logs(output_def["step"])
+                if "item" in output_def:
+                    output = output[output_def["item"]]
+                outputs.append(output)
+            return outputs
+
+    # TODO: export_func is None for now
+    return construct_ui(config, run_func, None)
+
+
+def build_from_dict(config: Union[str, dict]):
+    """Build a full UI from YAML config file"""
+
+    if isinstance(config, str):
+        with open(config) as f:
+            config_dict: dict = yaml.safe_load(f)
+    elif isinstance(config, dict):
+        config_dict = config
+    else:
+        raise ValueError(
+            f"config must be either a yaml path or a dict, got {type(config)}"
+        )
+
+    demos = []
+    for key, value in config_dict.items():
+        pipeline_def = import_dotted_string(key, safe=False)
+        demos.append(build_pipeline_ui(value, pipeline_def))
+    if len(demos) == 1:
+        demo = demos[0]
+    else:
+        demo = gr.TabbedInterface(demos, list(config_dict.keys()))
+
+    return demo
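Closing the loop from the export sketch near the top, a minimal sketch of serving the generated UI (the YAML path is illustrative):

from kotaemon.contribs.promptui.ui import build_from_dict

# config -> UI: accepts a YAML path or an already-loaded dict keyed by
# dotted pipeline paths
demo = build_from_dict("promptui.yml")
demo.launch()  # serves the gradio Blocks (or TabbedInterface for several pipelines)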
@@ -1,4 +1,4 @@
 from .base import BaseDocumentStore
-from .simple import InMemoryDocumentStore
+from .in_memory import InMemoryDocumentStore
 
 __all__ = ["BaseDocumentStore", "InMemoryDocumentStore"]
@@ -10,7 +10,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
     """Simple memory document store that store document in a dictionary"""
 
     def __init__(self):
-        self.store = {}
+        self._store = {}
 
     def add(
         self,
@@ -32,20 +32,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
             docs = [docs]
 
         for doc_id, doc in zip(doc_ids, docs):
-            if doc_id in self.store and not exist_ok:
+            if doc_id in self._store and not exist_ok:
                 raise ValueError(f"Document with id {doc_id} already exist")
-            self.store[doc_id] = doc
+            self._store[doc_id] = doc
 
     def get(self, ids: Union[List[str], str]) -> List[Document]:
         """Get document by id"""
         if not isinstance(ids, list):
             ids = [ids]
 
-        return [self.store[doc_id] for doc_id in ids]
+        return [self._store[doc_id] for doc_id in ids]
 
     def get_all(self) -> dict:
         """Get all documents"""
-        return self.store
+        return self._store
 
     def delete(self, ids: Union[List[str], str]):
         """Delete document by id"""
@@ -53,11 +53,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
             ids = [ids]
 
         for doc_id in ids:
-            del self.store[doc_id]
+            del self._store[doc_id]
 
     def save(self, path: Union[str, Path]):
         """Save document to path"""
-        store = {key: value.to_dict() for key, value in self.store.items()}
+        store = {key: value.to_dict() for key, value in self._store.items()}
         with open(path, "w") as f:
             json.dump(store, f)
@@ -65,4 +65,4 @@ class InMemoryDocumentStore(BaseDocumentStore):
        """Load document store from path"""
        with open(path) as f:
            store = json.load(f)
-        self.store = {key: Document.from_dict(value) for key, value in store.items()}
+        self._store = {key: Document.from_dict(value) for key, value in store.items()}
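The rename from `self.store` to `self._store` keeps the public API intact. A quick sketch of that API; the single-document `add` call assumes ids can be derived from `Document.id_`, consistent with the indexing pipeline below calling `doc_store.add(text)` with documents only:

from kotaemon.docstores import InMemoryDocumentStore
from kotaemon.documents.base import Document

store = InMemoryDocumentStore()
store.add(Document(text="Hello world", id_="doc-1"))  # assumption: id taken from id_
assert store.get("doc-1")[0].text == "Hello world"
store.save("docstore.json")  # persists {id: Document.to_dict()} as JSON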
@@ -1,4 +1,5 @@
 from haystack.schema import Document as HaystackDocument
+from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
 
 SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -20,3 +21,17 @@ class Document(BaseDocument):
         metadata = self.metadata or {}
         text = self.text
         return HaystackDocument(content=text, meta=metadata)
+
+
+class RetrievedDocument(Document):
+    """Subclass of Document with retrieval-related information
+
+    Attributes:
+        score (float): score of the document (from 0.0 to 1.0)
+        retrieval_metadata (dict): metadata from the retrieval process, can be used
+            by different components in a retrieved pipeline to communicate with each
+            other
+    """
+
+    score: float = Field(default=0.0)
+    retrieval_metadata: dict = Field(default={})
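A small sketch of the new subclass, mirroring how the retrieval pipeline further down constructs it (the score value is illustrative):

from kotaemon.documents.base import Document, RetrievedDocument

doc = Document(text="Hello world")
hit = RetrievedDocument(**doc.to_dict(), score=0.87)  # hypothetical score
print(hit.score, hit.retrieval_metadata)  # -> 0.87 {}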
@@ -27,7 +27,7 @@ class LangchainLLM(LLM):
             self._kwargs[param] = params.pop(param)
         super().__init__(**params)
 
-    @Param.decorate()
+    @Param.decorate(no_cache=True)
     def agent(self):
         return self._lc_class(**self._kwargs)
@@ -1,8 +1,10 @@
-from typing import List
+import uuid
+from typing import List, Optional
 
 from theflow import Node, Param
 
 from ..base import BaseComponent
+from ..docstores import BaseDocumentStore
 from ..documents.base import Document
 from ..embeddings import BaseEmbeddings
 from ..vectorstores import BaseVectorStore
@@ -18,21 +20,30 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
     """
 
     vector_store: Param[BaseVectorStore] = Param()
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: Node[BaseEmbeddings] = Node()
-    # TODO: populate to document store as well when it's finished
     # TODO: refer to llama_index's storage as well
 
     def run_raw(self, text: str) -> None:
-        self.vector_store.add([self.embedding(text)])
+        document = Document(text=text, id_=str(uuid.uuid4()))
+        self.run_batch_document([document])
 
     def run_batch_raw(self, text: List[str]) -> None:
-        self.vector_store.add(self.embedding(text))
+        documents = [Document(t, id_=str(uuid.uuid4())) for t in text]
+        self.run_batch_document(documents)
 
     def run_document(self, text: Document) -> None:
-        self.vector_store.add([self.embedding(text)])
+        self.run_batch_document([text])
 
     def run_batch_document(self, text: List[Document]) -> None:
-        self.vector_store.add(self.embedding(text))
+        embeddings = self.embedding(text)
+        self.vector_store.add(
+            embeddings=embeddings,
+            ids=[t.id_ for t in text],
+        )
+        if self.doc_store:
+            self.doc_store.add(text)
 
     def is_document(self, text) -> bool:
         if isinstance(text, Document):
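A sketch of the updated indexing flow with a document store attached; `db` and `embedding` stand for an already-configured vector store and embedding node (e.g. `ChromaVectorStore` / `AzureOpenAIEmbeddings`, as in the tests further down):

import uuid

from kotaemon.docstores import InMemoryDocumentStore
from kotaemon.documents.base import Document
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline

pipeline = IndexVectorStoreFromDocumentPipeline(
    vector_store=db, embedding=embedding, doc_store=InMemoryDocumentStore()
)
pipeline(text=Document(text="Hello world", id_=str(uuid.uuid4())))
# embeddings land in the vector store; the raw Document lands in the doc store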
@@ -1,47 +1,87 @@
-from typing import List
+from abc import abstractmethod
+from typing import List, Optional
 
 from theflow import Node, Param
 
 from ..base import BaseComponent
-from ..documents.base import Document
+from ..docstores import BaseDocumentStore
+from ..documents.base import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..vectorstores import BaseVectorStore
 
 
-class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
+class BaseRetrieval(BaseComponent):
+    """Define the base interface of a retrieval pipeline"""
+
+    @abstractmethod
+    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
+        ...
+
+    @abstractmethod
+    def run_batch_raw(
+        self, text: List[str], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        ...
+
+    @abstractmethod
+    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
+        ...
+
+    @abstractmethod
+    def run_batch_document(
+        self, text: List[Document], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        ...
+
+
+class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
     """Retrieve list of documents from vector store"""
 
     vector_store: Param[BaseVectorStore] = Param()
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: Node[BaseEmbeddings] = Node()
-    # TODO: populate to document store as well when it's finished
     # TODO: refer to llama_index's storage as well
 
-    def run_raw(self, text: str) -> List[str]:
-        emb = self.embedding(text)
-        return self.vector_store.query(embedding=emb)[2]
+    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
+        return self.run_batch_raw([text], top_k=top_k)[0]
+
+    def run_batch_raw(
+        self, text: List[str], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        if self.doc_store is None:
+            raise ValueError(
+                "doc_store is not provided. Please provide a doc_store to "
+                "retrieve the documents"
+            )
 
-    def run_batch_raw(self, text: List[str]) -> List[List[str]]:
         result = []
         for each_text in text:
             emb = self.embedding(each_text)
-            result.append(self.vector_store.query(embedding=emb)[2])
+            _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
+            docs = self.doc_store.get(ids)
+            each_result = [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(docs, scores)
+            ]
+            result.append(each_result)
         return result
 
-    def run_document(self, text: Document) -> List[str]:
-        return self.run_raw(text.text)
+    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
+        return self.run_raw(text.text, top_k)
 
-    def run_batch_document(self, text: List[Document]) -> List[List[str]]:
-        input_text = [each.text for each in text]
-        return self.run_batch_raw(input_text)
+    def run_batch_document(
+        self, text: List[Document], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        return self.run_batch_raw(text=[t.text for t in text], top_k=top_k)
 
-    def is_document(self, text) -> bool:
+    def is_document(self, text, *args, **kwargs) -> bool:
         if isinstance(text, Document):
             return True
         elif isinstance(text, List) and isinstance(text[0], Document):
             return True
         return False
 
-    def is_batch(self, text) -> bool:
+    def is_batch(self, text, *args, **kwargs) -> bool:
         if isinstance(text, list):
             return True
         return False
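And the matching retrieval side, which now requires the same `doc_store` so that ids coming back from the vector store can be resolved into full documents. A sketch, with `db`, `doc_store`, and `embedding` assumed configured as in the indexing sketch above; forwarding `top_k` through the pipeline call is an assumption based on the signatures above:

from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline

retrieval = RetrieveDocumentFromVectorStorePipeline(
    vector_store=db, doc_store=doc_store, embedding=embedding
)
hits = retrieval(text="Hello world", top_k=3)  # -> List[RetrievedDocument]
for hit in hits:
    print(hit.score, hit.text)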
@@ -144,8 +144,8 @@ class LlamaIndexVectorStore(BaseVectorStore):
                 query_embedding=embedding,
                 similarity_top_k=top_k,
                 node_ids=ids,
-            ),
                 **kwargs,
+            ),
         )
 
         embeddings = []
@@ -3,7 +3,7 @@ minversion = 7.4.0
 testpaths = tests
 addopts = -ra -q
 log_cli=true
-log_level=DEBUG
+log_level=WARNING
 log_format = %(asctime)s %(levelname)s %(message)s
 log_date_format = %Y-%m-%d %H:%M:%S
 log_file = logs/pytest-logs.txt
setup.py
@@ -34,6 +34,7 @@ setuptools.setup(
         "llama-index",
         "llama-hub",
         "nltk",
+        "gradio",
     ],
     extras_require={
         "dev": [
@@ -1,9 +1,11 @@
 import json
 from pathlib import Path
+from typing import cast
 
 import pytest
 from openai.api_resources.embedding import Embedding
 
+from kotaemon.docstores import InMemoryDocumentStore
 from kotaemon.documents.base import Document
 from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
@@ -21,6 +23,7 @@ def mock_openai_embedding(monkeypatch):
 
 def test_indexing(mock_openai_embedding, tmp_path):
     db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
     embedding = AzureOpenAIEmbeddings(
         model="text-embedding-ada-002",
         deployment="embedding-deployment",
@@ -29,15 +32,19 @@ def test_indexing(mock_openai_embedding, tmp_path):
     )
 
     pipeline = IndexVectorStoreFromDocumentPipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, embedding=embedding, doc_store=doc_store
     )
+    pipeline.doc_store = cast(InMemoryDocumentStore, pipeline.doc_store)
     assert pipeline.vector_store._collection.count() == 0, "Expected empty collection"
+    assert len(pipeline.doc_store._store) == 0, "Expected empty doc store"
     pipeline(text=Document(text="Hello world"))
     assert pipeline.vector_store._collection.count() == 1, "Index 1 item"
+    assert len(pipeline.doc_store._store) == 1, "Expected 1 document"
 
 
 def test_retrieving(mock_openai_embedding, tmp_path):
     db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
     embedding = AzureOpenAIEmbeddings(
         model="text-embedding-ada-002",
         deployment="embedding-deployment",
@@ -46,14 +53,14 @@ def test_retrieving(mock_openai_embedding, tmp_path):
     )
 
     index_pipeline = IndexVectorStoreFromDocumentPipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, embedding=embedding, doc_store=doc_store
     )
     retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, doc_store=doc_store, embedding=embedding
     )
 
     index_pipeline(text=Document(text="Hello world"))
     output = retrieval_pipeline(text=["Hello world", "Hello world"])
 
-    assert len(output) == 2, "Expected 2 results"
-    assert output[0] == output[1], "Expected identical results"
+    assert len(output) == 2, "Expect 2 results"
+    assert output[0] == output[1], "Expect identical results"
tests/test_promptui.py (new file)
@@ -0,0 +1,86 @@
+import pytest
+
+from kotaemon.contribs.promptui.config import export_pipeline_to_config
+from kotaemon.contribs.promptui.ui import build_from_dict
+
+
+@pytest.fixture()
+def simple_pipeline_cls(tmp_path):
+    """Create a pipeline class that can be used"""
+    from typing import List
+
+    from theflow import Node
+
+    from kotaemon.base import BaseComponent
+    from kotaemon.embeddings import AzureOpenAIEmbeddings
+    from kotaemon.llms.completions.openai import AzureOpenAI
+    from kotaemon.pipelines.retrieving import (
+        RetrieveDocumentFromVectorStorePipeline,
+    )
+    from kotaemon.vectorstores import ChromaVectorStore
+
+    class Pipeline(BaseComponent):
+        vectorstore_path: str = str(tmp_path)
+        llm: Node[AzureOpenAI] = Node(
+            default=AzureOpenAI,
+            default_kwargs={
+                "openai_api_base": "https://test.openai.azure.com/",
+                "openai_api_key": "some-key",
+                "openai_api_version": "2023-03-15-preview",
+                "deployment_name": "gpt35turbo",
+                "temperature": 0,
+                "request_timeout": 60,
+            },
+        )
+
+        @Node.decorate(depends_on=["vectorstore_path"])
+        def retrieving_pipeline(self):
+            vector_store = ChromaVectorStore(self.vectorstore_path)
+            embedding = AzureOpenAIEmbeddings(
+                model="text-embedding-ada-002",
+                deployment="embedding-deployment",
+                openai_api_base="https://test.openai.azure.com/",
+                openai_api_key="some-key",
+            )
+
+            return RetrieveDocumentFromVectorStorePipeline(
+                vector_store=vector_store, embedding=embedding
+            )
+
+        def run_raw(self, text: str) -> str:
+            matched_texts: List[str] = self.retrieving_pipeline(text)
+            return self.llm("\n".join(matched_texts)).text[0]
+
+    return Pipeline
+
+
+Pipeline = simple_pipeline_cls
+
+
+class TestPromptConfig:
+    def test_export_prompt_config(self, simple_pipeline_cls):
+        """Test if the prompt config is exported correctly"""
+        pipeline = simple_pipeline_cls()
+        config_dict = export_pipeline_to_config(pipeline)
+        config = list(config_dict.values())[0]
+
+        assert "inputs" in config, "inputs should be in config"
+        assert "text" in config["inputs"], "inputs should have config"
+
+        assert "params" in config, "params should be in config"
+        assert "vectorstore_path" in config["params"]
+        assert "llm.deployment_name" in config["params"]
+        assert "llm.openai_api_base" in config["params"]
+        assert "llm.openai_api_key" in config["params"]
+        assert "llm.openai_api_version" in config["params"]
+        assert "llm.request_timeout" in config["params"]
+        assert "llm.temperature" in config["params"]
+
+
+class TestPromptUI:
+    def test_uigeneration(self, simple_pipeline_cls):
+        """Test if the gradio UI is exposed without any problem"""
+        pipeline = simple_pipeline_cls()
+        config = export_pipeline_to_config(pipeline)
+
+        build_from_dict(config)