From c6dd01e8203aacdecd01541cac8429c596fe889a Mon Sep 17 00:00:00 2001
From: "Nguyen Trung Duc (john)"
Date: Thu, 21 Sep 2023 14:27:23 +0700
Subject: [PATCH] [AUR-338, AUR-406, AUR-407] Export pipeline to config for PromptUI. Construct PromptUI dynamically based on config. (#16)

From pipeline > config > UI. Provides an example project for promptui.

- Pipeline to config: `kotaemon.contribs.promptui.config.export_pipeline_to_config`.
  The config follows the schema specified in this document:
  https://cinnamon-ai.atlassian.net/wiki/spaces/ATM/pages/2748711193/Technical+Detail.
  Note: this implementation excludes the logs, which will be handled in AUR-408.
- Config to UI: `kotaemon.contribs.promptui.ui.build_from_dict`
- Example project is located at `examples/promptui/`
---
 .gitsecret/keys/pubring.kbx                   | Bin 10882 -> 8942 bytes
 .gitsecret/paths/mapping.cfg                  |   2 +-
 .pre-commit-config.yaml                       |   3 +-
 credentials.txt.secret                        | Bin 1811 -> 1848 bytes
 knowledgehub/contribs/promptui/base.py        |  20 +++
 knowledgehub/contribs/promptui/config.py      | 131 +++++++++++++++
 knowledgehub/contribs/promptui/ui.py          | 155 +++++++++++++++++-
 knowledgehub/docstores/__init__.py            |   2 +-
 .../docstores/{simple.py => in_memory.py}     |  16 +-
 knowledgehub/documents/base.py                |  15 ++
 knowledgehub/llms/completions/base.py         |   2 +-
 knowledgehub/pipelines/indexing.py            |  23 ++-
 knowledgehub/pipelines/retrieving.py          |  72 ++++++--
 knowledgehub/vectorstores/base.py             |   2 +-
 pytest.ini                                    |   2 +-
 setup.py                                      |   1 +
 tests/test_indexing_retrieval.py              |  17 +-
 tests/test_promptui.py                        |  86 ++++++++++
 18 files changed, 503 insertions(+), 46 deletions(-)
 create mode 100644 knowledgehub/contribs/promptui/base.py
 rename knowledgehub/docstores/{simple.py => in_memory.py} (80%)
 create mode 100644 tests/test_promptui.py

diff --git a/.gitsecret/keys/pubring.kbx b/.gitsecret/keys/pubring.kbx
index 95dcf6f56642f7428120ea1923a29105e55abe92..c7ae59993acba590d7e75bd1826f37bd71886c60 100644
GIT binary patch
delta 20
[base85 binary data omitted]

delta 56
[base85 binary data omitted]

diff --git a/.gitsecret/paths/mapping.cfg b/.gitsecret/paths/mapping.cfg
index 576d1d7..bc9d978 100644
--- a/.gitsecret/paths/mapping.cfg
+++ b/.gitsecret/paths/mapping.cfg
@@ -1 +1 @@
-credentials.txt:1e17fa46dd8353b5ded588b32983ac7d800e70fd16bc5831663b9aaefc409011
+credentials.txt:272c4eb7f422bebcc5d0f1da8bde47016b185ba8cb6ca06639bb2a3e88ea9bc5
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index dd2dd90..b12ed8a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -25,7 +25,7 @@ repos:
     rev: 4.0.1
     hooks:
       - id: flake8
-        args: ["--max-line-length", "88"]
+        args: ["--max-line-length", "88", "--extend-ignore", "E203"]
   - repo: https://github.com/myint/autoflake
     rev: v1.4
     hooks:
@@ -47,4 +47,5 @@ repos:
     rev: "v1.5.1"
     hooks:
       - id: mypy
+        additional_dependencies: [types-PyYAML==6.0.12.11]
         args: ["--check-untyped-defs", "--ignore-missing-imports"]
diff --git a/credentials.txt.secret b/credentials.txt.secret
index b7b1e4560e799c9dc37c4b3170363ce7ebf61655..2597a44aa27236c29aa2980d4a3737f001b12a3e 100644
GIT binary patch
literal 1848
[base85 binary data omitted]

literal 1811
[base85 binary data omitted]
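The intended flow described in the commit message (pipeline > config > UI) can be sketched end to end as follows. This is a minimal sketch, assuming a `SimplePipeline` component exists in the example project at the dotted path used here; the class name, module path, and file name are hypothetical:

    from kotaemon.contribs.promptui.config import export_pipeline_to_config
    from kotaemon.contribs.promptui.ui import build_from_dict

    from examples.promptui.pipeline import SimplePipeline  # hypothetical example project

    # Step 1: pipeline -> config (optionally persisted to YAML)
    config = export_pipeline_to_config(SimplePipeline, path="promptui.yml")

    # Step 2: config -> UI (accepts the dict directly or the YAML path)
    demo = build_from_dict("promptui.yml")
    demo.launch()  # a gradio Blocks / TabbedInterface app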
diff --git a/knowledgehub/contribs/promptui/base.py b/knowledgehub/contribs/promptui/base.py
new file mode 100644
index 0000000..ac7c0f0
--- /dev/null
+++ b/knowledgehub/contribs/promptui/base.py
@@ -0,0 +1,20 @@
+import gradio as gr
+
+COMPONENTS_CLASS = {
+    "text": gr.components.Textbox,
+    "checkbox": gr.components.CheckboxGroup,
+    "dropdown": gr.components.Dropdown,
+    "file": gr.components.File,
+    "image": gr.components.Image,
+    "number": gr.components.Number,
+    "radio": gr.components.Radio,
+    "slider": gr.components.Slider,
+}
+SUPPORTED_COMPONENTS = set(COMPONENTS_CLASS.keys())
+DEFAULT_COMPONENT_BY_TYPES = {
+    "str": "text",
+    "bool": "checkbox",
+    "int": "number",
+    "float": "number",
+    "list": "dropdown",
+}
diff --git a/knowledgehub/contribs/promptui/config.py b/knowledgehub/contribs/promptui/config.py
index 7e38ade..a581c75 100644
--- a/knowledgehub/contribs/promptui/config.py
+++ b/knowledgehub/contribs/promptui/config.py
@@ -1 +1,132 @@
 """Get config from Pipeline"""
+import inspect
+from pathlib import Path
+from typing import Any, Dict, Optional, Type, Union
+
+import yaml
+
+from ...base import BaseComponent
+from .base import DEFAULT_COMPONENT_BY_TYPES
+
+
+def config_from_value(value: Any) -> dict:
+    """Get the config from default value
+
+    Args:
+        value (Any): default value
+
+    Returns:
+        dict: config
+    """
+    component = DEFAULT_COMPONENT_BY_TYPES.get(type(value).__name__, "text")
+    return {
+        "component": component,
+        "params": {
+            "value": value,
+        },
+    }
+
+
+def handle_param(param: dict) -> dict:
+    """Convert param definition into promptui-compliant config
+
+    Supported gradio UI components (https://www.gradio.app/docs/components):
+        - CheckBoxGroup: list (multi select)
+        - DropDown: list (single select)
+        - File
+        - Image
+        - Number: int / float
+        - Radio: list (single select)
+        - Slider: int / float
+        - TextBox: str
+    """
+    params = {}
+    default = param.get("default", None)
+    if isinstance(default, str) and default.startswith("{{") and default.endswith("}}"):
+        default = None
+    if default is not None:
+        params["value"] = default
+
+    type_: str = type(default).__name__ if default is not None else ""
+    ui_component = DEFAULT_COMPONENT_BY_TYPES.get(type_, "text")
+
+    return {
+        "component": ui_component,
+        "params": params,
+    }
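As these helpers read, plain defaults map through DEFAULT_COMPONENT_BY_TYPES while "{{ ... }}" template placeholders are dropped. A small sketch of the expected shapes, derived from the code above rather than from a verified run:

    from kotaemon.contribs.promptui.config import config_from_value, handle_param

    # float -> "number", list -> "dropdown" per DEFAULT_COMPONENT_BY_TYPES
    assert config_from_value(0.5) == {"component": "number", "params": {"value": 0.5}}
    assert config_from_value(["a", "b"])["component"] == "dropdown"

    # A "{{ ... }}" default is treated as unset, so the param falls back to a
    # bare text component with no preset value
    assert handle_param({"default": "{{ callback }}"}) == {
        "component": "text",
        "params": {},
    }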
+
+
+def handle_node(node: dict) -> dict:
+    """Convert node definition into promptui-compliant config"""
+    config = {}
+    for name, param_def in node.get("params", {}).items():
+        if isinstance(param_def["default_callback"], str):
+            continue
+        config[name] = handle_param(param_def)
+    for name, node_def in node.get("nodes", {}).items():
+        if isinstance(node_def["default_callback"], str):
+            continue
+        for key, value in handle_node(node_def["default"]).items():
+            config[f"{name}.{key}"] = value
+        for key, value in node_def["default_kwargs"].items():
+            config[f"{name}.{key}"] = config_from_value(value)
+
+    return config
+
+
+def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
+    """Get the input from the pipeline"""
+    if not hasattr(pipeline, "run_raw"):
+        return {}
+    signature = inspect.signature(pipeline.run_raw)
+    inputs: Dict[str, Dict] = {}
+    for name, param in signature.parameters.items():
+        if name in ["self", "args", "kwargs"]:
+            continue
+        input_def: Dict[str, Optional[Any]] = {"component": "text"}
+        default = param.default
+        if default is param.empty:
+            inputs[name] = input_def
+            continue
+
+        params = {}
+        params["value"] = default
+        type_ = type(default).__name__ if default is not None else None
+        ui_component = None
+        if type_ is not None:
+            ui_component = "text"
+
+        input_def["component"] = ui_component
+        input_def["params"] = params
+
+        inputs[name] = input_def
+
+    return inputs
+
+
+def export_pipeline_to_config(
+    pipeline: Union[BaseComponent, Type[BaseComponent]],
+    path: Optional[str] = None,
+) -> dict:
+    """Export a pipeline to a promptui-compliant config dict"""
+    if inspect.isclass(pipeline):
+        pipeline = pipeline()
+
+    pipeline_def = pipeline.describe()
+    config = {
+        f"{pipeline.__module__}.{pipeline.__class__.__name__}": {
+            "params": handle_node(pipeline_def),
+            "inputs": handle_input(pipeline),
+            "outputs": [{"step": ".", "component": "text"}],
+        }
+    }
+    if path is not None:
+        old_config = config
+        if Path(path).is_file():
+            with open(path) as f:
+                old_config = yaml.safe_load(f)
+                old_config.update(config)
+        with open(path, "w") as f:
+            yaml.safe_dump(old_config, f)
+
+    return config
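Keyed by the pipeline's dotted module path, the exported dict then has roughly this shape. The dotted key and the param values below are illustrative, following the schema document linked in the commit message:

    # Illustrative shape of export_pipeline_to_config() output; the dotted
    # key comes from the pipeline's __module__ and class name (path below
    # is hypothetical)
    expected_shape = {
        "examples.promptui.pipeline.SimplePipeline": {
            "params": {
                # one handle_param()/config_from_value() entry per exposed param
                "llm.temperature": {"component": "number", "params": {"value": 0}},
            },
            "inputs": {
                # one entry per run_raw() argument
                "text": {"component": "text"},
            },
            "outputs": [{"step": ".", "component": "text"}],
        }
    }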
+ """ + inputs, outputs, params = [], [], [] + for name, component_def in config.get("inputs", {}).items(): + if "params" not in component_def: + component_def["params"] = {} + component_def["params"]["interactive"] = True + component = get_component(component_def) + if hasattr(component, "label") and not component.label: # type: ignore + component.label = name # type: ignore + + inputs.append(component) + + for name, component_def in config.get("params", {}).items(): + if "params" not in component_def: + component_def["params"] = {} + component_def["params"]["interactive"] = True + component = get_component(component_def) + if hasattr(component, "label") and not component.label: # type: ignore + component.label = name # type: ignore + + params.append(component) + + for idx, component_def in enumerate(config.get("outputs", [])): + if "params" not in component_def: + component_def["params"] = {} + component_def["params"]["interactive"] = False + component = get_component(component_def) + if hasattr(component, "label") and not component.label: # type: ignore + component.label = f"Output {idx}" + + outputs.append(component) + + temp = gr.Tab + with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo: + with gr.Accordion(label="Usage", open=False): + gr.Markdown(USAGE_INSTRUCTION) + with gr.Row(): + run_btn = gr.Button("Run") + run_btn.click(func_run, inputs=inputs + params, outputs=outputs) + export_btn = gr.Button("Export") + export_btn.click(func_export, inputs=None, outputs=None) + with gr.Row(): + with gr.Column(): + with temp("Inputs"): + for component in inputs: + component.render() + with temp("Params"): + for component in params: + component.render() + with gr.Column(): + for component in outputs: + component.render() + + return demo + + +def build_pipeline_ui(config: dict, pipeline_def): + """Build a tab from config file""" + inputs_name = list(config.get("inputs", {}).keys()) + params_name = list(config.get("params", {}).keys()) + outputs_def = config.get("outputs", []) + + def run_func(*args): + inputs = { + name: value for name, value in zip(inputs_name, args[: len(inputs_name)]) + } + params = { + name: value for name, value in zip(params_name, args[len(inputs_name) :]) + } + pipeline = pipeline_def() + pipeline.set(params) + pipeline(**inputs) + if outputs_def: + outputs = [] + for output_def in outputs_def: + output = pipeline.last_run.logs(output_def["step"]) + if "item" in output_def: + output = output[output_def["item"]] + outputs.append(output) + return outputs + + # TODO: export_func is None for now + return construct_ui(config, run_func, None) + + +def build_from_dict(config: Union[str, dict]): + """Build a full UI from YAML config file""" + + if isinstance(config, str): + with open(config) as f: + config_dict: dict = yaml.safe_load(f) + elif isinstance(config, dict): + config_dict = config + else: + raise ValueError( + f"config must be either a yaml path or a dict, got {type(config)}" + ) + + demos = [] + for key, value in config_dict.items(): + pipeline_def = import_dotted_string(key, safe=False) + demos.append(build_pipeline_ui(value, pipeline_def)) + if len(demos) == 1: + demo = demos[0] + else: + demo = gr.TabbedInterface(demos, list(config_dict.keys())) + + return demo diff --git a/knowledgehub/docstores/__init__.py b/knowledgehub/docstores/__init__.py index 88f6829..bee4fc5 100644 --- a/knowledgehub/docstores/__init__.py +++ b/knowledgehub/docstores/__init__.py @@ -1,4 +1,4 @@ from .base import BaseDocumentStore -from .simple import 
diff --git a/knowledgehub/docstores/__init__.py b/knowledgehub/docstores/__init__.py
index 88f6829..bee4fc5 100644
--- a/knowledgehub/docstores/__init__.py
+++ b/knowledgehub/docstores/__init__.py
@@ -1,4 +1,4 @@
 from .base import BaseDocumentStore
-from .simple import InMemoryDocumentStore
+from .in_memory import InMemoryDocumentStore
 
 __all__ = ["BaseDocumentStore", "InMemoryDocumentStore"]
diff --git a/knowledgehub/docstores/simple.py b/knowledgehub/docstores/in_memory.py
similarity index 80%
rename from knowledgehub/docstores/simple.py
rename to knowledgehub/docstores/in_memory.py
index 7f812e8..577363e 100644
--- a/knowledgehub/docstores/simple.py
+++ b/knowledgehub/docstores/in_memory.py
@@ -10,7 +10,7 @@ class InMemoryDocumentStore(BaseDocumentStore):
     """Simple memory document store that store document in a dictionary"""
 
     def __init__(self):
-        self.store = {}
+        self._store = {}
 
     def add(
         self,
@@ -32,20 +32,20 @@ class InMemoryDocumentStore(BaseDocumentStore):
             docs = [docs]
 
         for doc_id, doc in zip(doc_ids, docs):
-            if doc_id in self.store and not exist_ok:
+            if doc_id in self._store and not exist_ok:
                 raise ValueError(f"Document with id {doc_id} already exist")
-            self.store[doc_id] = doc
+            self._store[doc_id] = doc
 
     def get(self, ids: Union[List[str], str]) -> List[Document]:
         """Get document by id"""
         if not isinstance(ids, list):
             ids = [ids]
-        return [self.store[doc_id] for doc_id in ids]
+        return [self._store[doc_id] for doc_id in ids]
 
     def get_all(self) -> dict:
         """Get all documents"""
-        return self.store
+        return self._store
 
     def delete(self, ids: Union[List[str], str]):
         """Delete document by id"""
@@ -53,11 +53,11 @@ class InMemoryDocumentStore(BaseDocumentStore):
             ids = [ids]
 
         for doc_id in ids:
-            del self.store[doc_id]
+            del self._store[doc_id]
 
     def save(self, path: Union[str, Path]):
         """Save document to path"""
-        store = {key: value.to_dict() for key, value in self.store.items()}
+        store = {key: value.to_dict() for key, value in self._store.items()}
         with open(path, "w") as f:
             json.dump(store, f)
 
@@ -65,4 +65,4 @@ class InMemoryDocumentStore(BaseDocumentStore):
         """Load document store from path"""
         with open(path) as f:
             store = json.load(f)
-        self.store = {key: Document.from_dict(value) for key, value in store.items()}
+        self._store = {key: Document.from_dict(value) for key, value in store.items()}
diff --git a/knowledgehub/documents/base.py b/knowledgehub/documents/base.py
index 29bbcec..30f540d 100644
--- a/knowledgehub/documents/base.py
+++ b/knowledgehub/documents/base.py
@@ -1,4 +1,5 @@
 from haystack.schema import Document as HaystackDocument
+from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
 
 SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -20,3 +21,17 @@ class Document(BaseDocument):
         metadata = self.metadata or {}
         text = self.text
         return HaystackDocument(content=text, meta=metadata)
+
+
+class RetrievedDocument(Document):
+    """Subclass of Document with retrieval-related information
+
+    Attributes:
+        score (float): score of the document (from 0.0 to 1.0)
+        retrieval_metadata (dict): metadata from the retrieval process, can be used
+            by different components in a retrieval pipeline to communicate with each
+            other
+    """
+
+    score: float = Field(default=0.0)
+    retrieval_metadata: dict = Field(default={})
diff --git a/knowledgehub/llms/completions/base.py b/knowledgehub/llms/completions/base.py
index 6cbba52..0b03d72 100644
--- a/knowledgehub/llms/completions/base.py
+++ b/knowledgehub/llms/completions/base.py
@@ -27,7 +27,7 @@ class LangchainLLM(LLM):
                 self._kwargs[param] = params.pop(param)
         super().__init__(**params)
 
-    @Param.decorate()
+    @Param.decorate(no_cache=True)
     def agent(self):
         return self._lc_class(**self._kwargs)
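A round-trip sketch of the renamed store, assuming add() falls back to each document's id_ when explicit ids are not passed, as the indexing pipeline's bare doc_store.add(text) call below suggests:

    from kotaemon.docstores import InMemoryDocumentStore
    from kotaemon.documents.base import Document

    store = InMemoryDocumentStore()
    doc = Document(text="Hello world")
    store.add([doc])                       # assumed: ids default to doc.id_

    # get() accepts a single id or a list and always returns a list
    assert store.get(doc.id_)[0].text == "Hello world"

    # save()/load() round-trip through JSON via Document.to_dict()/from_dict()
    store.save("docstore.json")
    restored = InMemoryDocumentStore()
    restored.load("docstore.json")
    assert len(restored.get_all()) == 1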
diff --git a/knowledgehub/pipelines/indexing.py b/knowledgehub/pipelines/indexing.py
index a3b0c72..1935268 100644
--- a/knowledgehub/pipelines/indexing.py
+++ b/knowledgehub/pipelines/indexing.py
@@ -1,8 +1,10 @@
-from typing import List
+import uuid
+from typing import List, Optional
 
 from theflow import Node, Param
 
 from ..base import BaseComponent
+from ..docstores import BaseDocumentStore
 from ..documents.base import Document
 from ..embeddings import BaseEmbeddings
 from ..vectorstores import BaseVectorStore
@@ -18,21 +20,30 @@ class IndexVectorStoreFromDocumentPipeline(BaseComponent):
     """
 
     vector_store: Param[BaseVectorStore] = Param()
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: Node[BaseEmbeddings] = Node()
-    # TODO: populate to document store as well when it's finished
+    # TODO: refer to llama_index's storage as well
 
     def run_raw(self, text: str) -> None:
-        self.vector_store.add([self.embedding(text)])
+        document = Document(text=text, id_=str(uuid.uuid4()))
+        self.run_batch_document([document])
 
     def run_batch_raw(self, text: List[str]) -> None:
-        self.vector_store.add(self.embedding(text))
+        documents = [Document(t, id_=str(uuid.uuid4())) for t in text]
+        self.run_batch_document(documents)
 
     def run_document(self, text: Document) -> None:
-        self.vector_store.add([self.embedding(text)])
+        self.run_batch_document([text])
 
     def run_batch_document(self, text: List[Document]) -> None:
-        self.vector_store.add(self.embedding(text))
+        embeddings = self.embedding(text)
+        self.vector_store.add(
+            embeddings=embeddings,
+            ids=[t.id_ for t in text],
+        )
+        if self.doc_store:
+            self.doc_store.add(text)
 
     def is_document(self, text) -> bool:
         if isinstance(text, Document):
diff --git a/knowledgehub/pipelines/retrieving.py b/knowledgehub/pipelines/retrieving.py
index 4ba43ba..64d4e6b 100644
--- a/knowledgehub/pipelines/retrieving.py
+++ b/knowledgehub/pipelines/retrieving.py
@@ -1,47 +1,87 @@
-from typing import List
+from abc import abstractmethod
+from typing import List, Optional
 
 from theflow import Node, Param
 
 from ..base import BaseComponent
-from ..documents.base import Document
+from ..docstores import BaseDocumentStore
+from ..documents.base import Document, RetrievedDocument
 from ..embeddings import BaseEmbeddings
 from ..vectorstores import BaseVectorStore
 
 
-class RetrieveDocumentFromVectorStorePipeline(BaseComponent):
+class BaseRetrieval(BaseComponent):
+    """Define the base interface of a retrieval pipeline"""
+
+    @abstractmethod
+    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
+        ...
+
+    @abstractmethod
+    def run_batch_raw(
+        self, text: List[str], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        ...
+
+    @abstractmethod
+    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
+        ...
+
+    @abstractmethod
+    def run_batch_document(
+        self, text: List[Document], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        ...
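With the document store wired in, indexing now writes embeddings and ids to the vector store and mirrors the raw documents into the doc store. A sketch following the tests below; the Chroma path and Azure credentials are placeholders:

    from kotaemon.docstores import InMemoryDocumentStore
    from kotaemon.documents.base import Document
    from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
    from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
    from kotaemon.vectorstores import ChromaVectorStore

    index_pipeline = IndexVectorStoreFromDocumentPipeline(
        vector_store=ChromaVectorStore(path="./chroma"),  # hypothetical path
        doc_store=InMemoryDocumentStore(),
        embedding=AzureOpenAIEmbeddings(
            model="text-embedding-ada-002",
            deployment="embedding-deployment",
            openai_api_base="https://example.openai.azure.com/",  # placeholder
            openai_api_key="some-key",                            # placeholder
        ),
    )

    # run_raw/run_document/run_batch_* all funnel into run_batch_document,
    # which embeds, stores vectors keyed by Document.id_, and adds the
    # documents to doc_store when one is configured
    index_pipeline(text=Document(text="Hello world"))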
+
+
+class RetrieveDocumentFromVectorStorePipeline(BaseRetrieval):
     """Retrieve list of documents from vector store"""
 
     vector_store: Param[BaseVectorStore] = Param()
+    doc_store: Optional[BaseDocumentStore] = None
     embedding: Node[BaseEmbeddings] = Node()
-    # TODO: populate to document store as well when it's finished
     # TODO: refer to llama_index's storage as well
 
-    def run_raw(self, text: str) -> List[str]:
-        emb = self.embedding(text)
-        return self.vector_store.query(embedding=emb)[2]
+    def run_raw(self, text: str, top_k: int = 1) -> List[RetrievedDocument]:
+        return self.run_batch_raw([text], top_k=top_k)[0]
+
+    def run_batch_raw(
+        self, text: List[str], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        if self.doc_store is None:
+            raise ValueError(
+                "doc_store is not provided. Please provide a doc_store to "
+                "retrieve the documents"
+            )
 
-    def run_batch_raw(self, text: List[str]) -> List[List[str]]:
         result = []
         for each_text in text:
             emb = self.embedding(each_text)
-            result.append(self.vector_store.query(embedding=emb)[2])
+            _, scores, ids = self.vector_store.query(embedding=emb, top_k=top_k)
+            docs = self.doc_store.get(ids)
+            each_result = [
+                RetrievedDocument(**doc.to_dict(), score=score)
+                for doc, score in zip(docs, scores)
+            ]
+            result.append(each_result)
         return result
 
-    def run_document(self, text: Document) -> List[str]:
-        return self.run_raw(text.text)
+    def run_document(self, text: Document, top_k: int = 1) -> List[RetrievedDocument]:
+        return self.run_raw(text.text, top_k)
 
-    def run_batch_document(self, text: List[Document]) -> List[List[str]]:
-        input_text = [each.text for each in text]
-        return self.run_batch_raw(input_text)
+    def run_batch_document(
+        self, text: List[Document], top_k: int = 1
+    ) -> List[List[RetrievedDocument]]:
+        return self.run_batch_raw(text=[t.text for t in text], top_k=top_k)
 
-    def is_document(self, text) -> bool:
+    def is_document(self, text, *args, **kwargs) -> bool:
         if isinstance(text, Document):
             return True
         elif isinstance(text, List) and isinstance(text[0], Document):
             return True
         return False
 
-    def is_batch(self, text) -> bool:
+    def is_batch(self, text, *args, **kwargs) -> bool:
         if isinstance(text, list):
             return True
         return False
diff --git a/knowledgehub/vectorstores/base.py b/knowledgehub/vectorstores/base.py
index ac66e6b..310c019 100644
--- a/knowledgehub/vectorstores/base.py
+++ b/knowledgehub/vectorstores/base.py
@@ -144,8 +144,8 @@ class LlamaIndexVectorStore(BaseVectorStore):
                 query_embedding=embedding,
                 similarity_top_k=top_k,
                 node_ids=ids,
+                **kwargs,
             ),
-            **kwargs,
         )
 
         embeddings = []
diff --git a/pytest.ini b/pytest.ini
index 1127b02..819c6b0 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,7 +3,7 @@ minversion = 7.4.0
 testpaths = tests
 addopts = -ra -q
 log_cli=true
-log_level=DEBUG
+log_level=WARNING
 log_format = %(asctime)s %(levelname)s %(message)s
 log_date_format = %Y-%m-%d %H:%M:%S
 log_file = logs/pytest-logs.txt
diff --git a/setup.py b/setup.py
index 91e9b67..819d363 100644
--- a/setup.py
+++ b/setup.py
@@ -34,6 +34,7 @@ setuptools.setup(
         "llama-index",
         "llama-hub",
         "nltk",
+        "gradio",
     ],
     extras_require={
         "dev": [
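Retrieval now requires the same doc_store, since the vector store only returns ids and scores that must be resolved back into documents. Continuing the indexing sketch above (reusing its vector store, doc store, and embedding):

    from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline

    retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
        vector_store=index_pipeline.vector_store,
        doc_store=index_pipeline.doc_store,
        embedding=index_pipeline.embedding,
    )

    # Each hit comes back as a RetrievedDocument with the similarity score attached
    for doc in retrieval_pipeline(text="Hello world"):
        print(doc.score, doc.text)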
diff --git a/tests/test_indexing_retrieval.py b/tests/test_indexing_retrieval.py
index 1dc3377..c253c25 100644
--- a/tests/test_indexing_retrieval.py
+++ b/tests/test_indexing_retrieval.py
@@ -1,9 +1,11 @@
 import json
 from pathlib import Path
+from typing import cast
 
 import pytest
 from openai.api_resources.embedding import Embedding
 
+from kotaemon.docstores import InMemoryDocumentStore
 from kotaemon.documents.base import Document
 from kotaemon.embeddings.openai import AzureOpenAIEmbeddings
 from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
@@ -21,6 +23,7 @@ def mock_openai_embedding(monkeypatch):
 
 def test_indexing(mock_openai_embedding, tmp_path):
     db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
     embedding = AzureOpenAIEmbeddings(
         model="text-embedding-ada-002",
         deployment="embedding-deployment",
@@ -29,15 +32,19 @@ def test_indexing(mock_openai_embedding, tmp_path):
     )
 
     pipeline = IndexVectorStoreFromDocumentPipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, embedding=embedding, doc_store=doc_store
     )
+    pipeline.doc_store = cast(InMemoryDocumentStore, pipeline.doc_store)
     assert pipeline.vector_store._collection.count() == 0, "Expected empty collection"
+    assert len(pipeline.doc_store._store) == 0, "Expected empty doc store"
     pipeline(text=Document(text="Hello world"))
     assert pipeline.vector_store._collection.count() == 1, "Index 1 item"
+    assert len(pipeline.doc_store._store) == 1, "Expected 1 document"
 
 
 def test_retrieving(mock_openai_embedding, tmp_path):
     db = ChromaVectorStore(path=str(tmp_path))
+    doc_store = InMemoryDocumentStore()
     embedding = AzureOpenAIEmbeddings(
         model="text-embedding-ada-002",
         deployment="embedding-deployment",
@@ -46,14 +53,14 @@ def test_retrieving(mock_openai_embedding, tmp_path):
     )
 
     index_pipeline = IndexVectorStoreFromDocumentPipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, embedding=embedding, doc_store=doc_store
     )
     retrieval_pipeline = RetrieveDocumentFromVectorStorePipeline(
-        vector_store=db, embedding=embedding
+        vector_store=db, doc_store=doc_store, embedding=embedding
     )
 
     index_pipeline(text=Document(text="Hello world"))
     output = retrieval_pipeline(text=["Hello world", "Hello world"])
-    assert len(output) == 2, "Expected 2 results"
-    assert output[0] == output[1], "Expected identical results"
+    assert len(output) == 2, "Expect 2 results"
+    assert output[0] == output[1], "Expect identical results"
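Only the signature of the mock_openai_embedding fixture is visible in this hunk; it presumably stubs the OpenAI embedding call so the tests run offline. A generic sketch of that pattern with pytest's monkeypatch, where the fixture body and canned payload are illustrative, not the repository's actual fixture:

    import pytest
    from openai.api_resources.embedding import Embedding


    @pytest.fixture
    def mock_openai_embedding(monkeypatch):
        def fake_create(*args, **kwargs):
            # Shape mimics an OpenAI embeddings response; values are made up
            return {"data": [{"embedding": [0.1, 0.2, 0.3], "index": 0}]}

        monkeypatch.setattr(Embedding, "create", fake_create)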
diff --git a/tests/test_promptui.py b/tests/test_promptui.py
new file mode 100644
index 0000000..74468e9
--- /dev/null
+++ b/tests/test_promptui.py
@@ -0,0 +1,86 @@
+import pytest
+
+from kotaemon.contribs.promptui.config import export_pipeline_to_config
+from kotaemon.contribs.promptui.ui import build_from_dict
+
+
+@pytest.fixture()
+def simple_pipeline_cls(tmp_path):
+    """Create a pipeline class that can be used"""
+    from typing import List
+
+    from theflow import Node
+
+    from kotaemon.base import BaseComponent
+    from kotaemon.embeddings import AzureOpenAIEmbeddings
+    from kotaemon.llms.completions.openai import AzureOpenAI
+    from kotaemon.pipelines.retrieving import (
+        RetrieveDocumentFromVectorStorePipeline,
+    )
+    from kotaemon.vectorstores import ChromaVectorStore
+
+    class Pipeline(BaseComponent):
+        vectorstore_path: str = str(tmp_path)
+        llm: Node[AzureOpenAI] = Node(
+            default=AzureOpenAI,
+            default_kwargs={
+                "openai_api_base": "https://test.openai.azure.com/",
+                "openai_api_key": "some-key",
+                "openai_api_version": "2023-03-15-preview",
+                "deployment_name": "gpt35turbo",
+                "temperature": 0,
+                "request_timeout": 60,
+            },
+        )
+
+        @Node.decorate(depends_on=["vectorstore_path"])
+        def retrieving_pipeline(self):
+            vector_store = ChromaVectorStore(self.vectorstore_path)
+            embedding = AzureOpenAIEmbeddings(
+                model="text-embedding-ada-002",
+                deployment="embedding-deployment",
+                openai_api_base="https://test.openai.azure.com/",
+                openai_api_key="some-key",
+            )
+
+            return RetrieveDocumentFromVectorStorePipeline(
+                vector_store=vector_store, embedding=embedding
+            )
+
+        def run_raw(self, text: str) -> str:
+            matched_texts: List[str] = self.retrieving_pipeline(text)
+            return self.llm("\n".join(matched_texts)).text[0]
+
+    return Pipeline
+
+
+Pipeline = simple_pipeline_cls
+
+
+class TestPromptConfig:
+    def test_export_prompt_config(self, simple_pipeline_cls):
+        """Test if the prompt config is exported correctly"""
+        pipeline = simple_pipeline_cls()
+        config_dict = export_pipeline_to_config(pipeline)
+        config = list(config_dict.values())[0]
+
+        assert "inputs" in config, "inputs should be in config"
+        assert "text" in config["inputs"], "inputs should have config"
+
+        assert "params" in config, "params should be in config"
+        assert "vectorstore_path" in config["params"]
+        assert "llm.deployment_name" in config["params"]
+        assert "llm.openai_api_base" in config["params"]
+        assert "llm.openai_api_key" in config["params"]
+        assert "llm.openai_api_version" in config["params"]
+        assert "llm.request_timeout" in config["params"]
+        assert "llm.temperature" in config["params"]
+
+
+class TestPromptUI:
+    def test_uigeneration(self, simple_pipeline_cls):
+        """Test if the gradio UI is exposed without any problem"""
+        pipeline = simple_pipeline_cls()
+        config = export_pipeline_to_config(pipeline)
+
+        build_from_dict(config)