Upgrade the declarative pipeline for cleaner interface (#51)
commit 9035e25666
parent aab982ddc4
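The thrust of this commit is replacing callback-style node and parameter definitions (`@Param.decorate()` / `@Node.decorate(...)` methods) with declarative class attributes built from `.withx()` and `ObjectInitDeclaration`. A minimal before/after sketch of the two styles, distilled from the pipelines changed below (import path for InMemoryVectorStore is assumed; it is not shown in this diff):

from theflow import Param
from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.vectorstores import InMemoryVectorStore  # assumed import path


class OldStylePipeline(BaseComponent):
    # before: components were produced by decorated callback methods
    @Param.decorate()
    def vector_store(self):
        return InMemoryVectorStore()


class NewStylePipeline(BaseComponent):
    # after: components are declared directly as class attributes
    vector_store: Param[InMemoryVectorStore] = Param(_(InMemoryVectorStore))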
@@ -22,4 +22,4 @@ try:
 except ImportError:
     pass
 
-__version__ = "0.0.4"
+__version__ = "0.2.0"
@@ -10,4 +10,4 @@ class SimpleRespondentChatbot(BaseChatBot):
     llm: Node[ChatLLM]
 
     def _get_message(self) -> str:
-        return self.llm(self.history).text[0]
+        return self.llm(self.history).text
@@ -63,19 +63,19 @@ def handle_node(node: dict) -> dict:
     """Convert node definition into promptui-compliant config"""
     config = {}
     for name, param_def in node.get("params", {}).items():
-        if isinstance(param_def["default_callback"], str):
+        if isinstance(param_def["auto_callback"], str):
             continue
         if param_def.get("ignore_ui", False):
             continue
         config[name] = handle_param(param_def)
     for name, node_def in node.get("nodes", {}).items():
-        if isinstance(node_def["default_callback"], str):
+        if isinstance(node_def["auto_callback"], str):
             continue
         if node_def.get("ignore_ui", False):
             continue
         for key, value in handle_node(node_def["default"]).items():
             config[f"{name}.{key}"] = value
-        for key, value in node_def["default_kwargs"].items():
+        for key, value in node_def.get("default_kwargs", {}).items():
             config[f"{name}.{key}"] = config_from_value(value)
 
     return config
@@ -124,11 +124,14 @@ def export_pipeline_to_config(
     if ui_type == "chat":
         params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
         params["system_message"] = {"component": "text", "params": {"value": ""}}
+        outputs = []
+        if hasattr(pipeline, "_promptui_outputs"):
+            outputs = pipeline._promptui_outputs
         config_obj: dict = {
             "ui-type": ui_type,
             "params": params,
             "inputs": {},
-            "outputs": [],
+            "outputs": outputs,
             "logs": {
                 "full_pipeline": {
                     "input": {
@@ -61,6 +61,9 @@ def from_log_to_dict(pipeline_cls: Type[BaseComponent], log_config: dict) -> dict:
         if name not in logged_infos:
             logged_infos[name] = [None] * len(dirs)
 
+        if step not in progress:
+            continue
+
         info = progress[step]
         if getter:
             if getter in allowed_resultlog_callbacks:
@@ -13,9 +13,9 @@ class John(Base):
         primary_hue: colors.Color | str = colors.neutral,
         secondary_hue: colors.Color | str = colors.neutral,
         neutral_hue: colors.Color | str = colors.neutral,
-        spacing_size: sizes.Size | str = sizes.spacing_lg,
+        spacing_size: sizes.Size | str = sizes.spacing_sm,
         radius_size: sizes.Size | str = sizes.radius_none,
-        text_size: sizes.Size | str = sizes.text_md,
+        text_size: sizes.Size | str = sizes.text_sm,
         font: fonts.Font
         | str
         | Iterable[fonts.Font | str] = (
@@ -79,8 +79,8 @@ class John(Base):
             button_cancel_background_fill_hover="*button_primary_background_fill_hover",
             button_cancel_text_color="*button_primary_text_color",
             # Padding
-            checkbox_label_padding="*spacing_md",
-            button_large_padding="*spacing_lg",
+            checkbox_label_padding="*spacing_sm",
+            button_large_padding="*spacing_sm",
             button_small_padding="*spacing_sm",
             # Borders
             block_border_width="0px",
@@ -91,5 +91,5 @@ class John(Base):
             # Block Labels
             block_title_text_weight="600",
             block_label_text_weight="600",
-            block_label_text_size="*text_md",
+            block_label_text_size="*text_sm",
         )
@@ -26,9 +26,9 @@ def build_from_dict(config: Union[str, dict]):
     for key, value in config_dict.items():
         pipeline_def = import_dotted_string(key, safe=False)
         if value["ui-type"] == "chat":
-            demos.append(build_chat_ui(value, pipeline_def))
+            demos.append(build_chat_ui(value, pipeline_def).queue())
         else:
-            demos.append(build_pipeline_ui(value, pipeline_def))
+            demos.append(build_pipeline_ui(value, pipeline_def).queue())
     if len(demos) == 1:
         demo = demos[0]
     else:
knowledgehub/contribs/promptui/ui/blocks.py (new file, 181 lines)
@@ -0,0 +1,181 @@
+from __future__ import annotations
+
+from typing import Any, AsyncGenerator
+
+import anyio
+from gradio import ChatInterface
+from gradio.components import IOComponent, get_component_instance
+from gradio.events import on
+from gradio.helpers import special_args
+from gradio.routes import Request
+
+
+class ChatBlock(ChatInterface):
+    """The ChatBlock subclasses ChatInterface to provide extra functionalities:
+
+    - Show additional outputs to the chat interface
+    - Disallow blank user message
+    """
+
+    def __init__(
+        self,
+        *args,
+        additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
+        **kwargs,
+    ):
+        if additional_outputs:
+            if not isinstance(additional_outputs, list):
+                additional_outputs = [additional_outputs]
+            self.additional_outputs = [
+                get_component_instance(i) for i in additional_outputs  # type: ignore
+            ]
+        else:
+            self.additional_outputs = []
+
+        super().__init__(*args, **kwargs)
+
+    async def _submit_fn(
+        self,
+        message: str,
+        history_with_input: list[list[str | None]],
+        request: Request,
+        *args,
+    ) -> tuple[Any, ...]:
+        input_args = args[: -len(self.additional_outputs)]
+        output_args = args[-len(self.additional_outputs) :]
+        if not message:
+            return history_with_input, history_with_input, *output_args
+
+        history = history_with_input[:-1]
+        inputs, _, _ = special_args(
+            self.fn, inputs=[message, history, *input_args], request=request
+        )
+
+        if self.is_async:
+            response = await self.fn(*inputs)
+        else:
+            response = await anyio.to_thread.run_sync(
+                self.fn, *inputs, limiter=self.limiter
+            )
+
+        output = []
+        if self.additional_outputs:
+            text = response[0]
+            output = response[1:]
+        else:
+            text = response
+
+        history.append([message, text])
+        return history, history, *output
+
+    async def _stream_fn(
+        self,
+        message: str,
+        history_with_input: list[list[str | None]],
+        *args,
+    ) -> AsyncGenerator:
+        raise NotImplementedError("Stream function not implemented for ChatBlock")
+
+    def _display_input(
+        self, message: str, history: list[list[str | None]]
+    ) -> tuple[list[list[str | None]], list[list[str | None]]]:
+        """Stop displaying the input message if the message is a blank string"""
+        if not message:
+            return history, history
+        return super()._display_input(message, history)
+
+    def _setup_events(self) -> None:
+        """Include additional outputs in the submit event"""
+        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
+        submit_triggers = (
+            [self.textbox.submit, self.submit_btn.click]
+            if self.submit_btn
+            else [self.textbox.submit]
+        )
+        submit_event = (
+            on(
+                submit_triggers,
+                self._clear_and_save_textbox,
+                [self.textbox],
+                [self.textbox, self.saved_input],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                self._display_input,
+                [self.saved_input, self.chatbot_state],
+                [self.chatbot, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            )
+            .then(
+                submit_fn,
+                [self.saved_input, self.chatbot_state]
+                + self.additional_inputs
+                + self.additional_outputs,
+                [self.chatbot, self.chatbot_state] + self.additional_outputs,
+                api_name=False,
+            )
+        )
+        self._setup_stop_events(submit_triggers, submit_event)
+
+        if self.retry_btn:
+            retry_event = (
+                self.retry_btn.click(
+                    self._delete_prev_fn,
+                    [self.chatbot_state],
+                    [self.chatbot, self.saved_input, self.chatbot_state],
+                    api_name=False,
+                    queue=False,
+                )
+                .then(
+                    self._display_input,
+                    [self.saved_input, self.chatbot_state],
+                    [self.chatbot, self.chatbot_state],
+                    api_name=False,
+                    queue=False,
+                )
+                .then(
+                    submit_fn,
+                    [self.saved_input, self.chatbot_state]
+                    + self.additional_inputs
+                    + self.additional_outputs,
+                    [self.chatbot, self.chatbot_state] + self.additional_outputs,
+                    api_name=False,
+                )
+            )
+            self._setup_stop_events([self.retry_btn.click], retry_event)
+
+        if self.undo_btn:
+            self.undo_btn.click(
+                self._delete_prev_fn,
+                [self.chatbot_state],
+                [self.chatbot, self.saved_input, self.chatbot_state],
+                api_name=False,
+                queue=False,
+            ).then(
+                lambda x: x,
+                [self.saved_input],
+                [self.textbox],
+                api_name=False,
+                queue=False,
+            )
+
+        if self.clear_btn:
+            self.clear_btn.click(
+                lambda: ([], [], None),
+                None,
+                [self.chatbot, self.chatbot_state, self.saved_input],
+                queue=False,
+                api_name=False,
+            )
+
+    def _setup_api(self) -> None:
+        api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
+
+        self.fake_api_btn.click(
+            api_fn,
+            [self.textbox, self.chatbot_state] + self.additional_inputs,
+            [self.textbox, self.chatbot_state] + self.additional_outputs,
+            api_name="chat",
+        )
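ChatBlock above accepts extra Gradio components through additional_outputs and returns their values alongside the chat reply. A hypothetical usage sketch, not part of this commit (the respond function and the Session info textbox are illustrative):

import gradio as gr

from kotaemon.contribs.promptui.ui.blocks import ChatBlock


def respond(message, history):
    # first element is the chat reply; the rest fill the additional_outputs
    return f"You said: {message}", f"{len(history)} turns so far"


info_box = gr.Textbox(label="Session info", interactive=False)
demo = ChatBlock(respond, additional_outputs=[info_box])
demo.queue().launch()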
@@ -8,6 +8,9 @@ from theflow.storage import storage
 from kotaemon.chatbot import ChatConversation
 from kotaemon.contribs.promptui.base import get_component
 from kotaemon.contribs.promptui.export import export
+from kotaemon.contribs.promptui.ui.blocks import ChatBlock
+
+from ..logs import ResultLog
 
 USAGE_INSTRUCTION = """## How to use:
 
@@ -87,8 +90,10 @@ def construct_chat_ui(
         outputs.append(component)
 
     sess = gr.State(value=None)
-    chatbot = gr.Chatbot(label="Chatbot")
-    chat = gr.ChatInterface(func_chat, chatbot=chatbot, additional_inputs=[sess])
+    chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
+    chat = ChatBlock(
+        func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
+    )
     param_state = gr.Textbox(interactive=False)
 
     with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
@@ -106,6 +111,7 @@ def construct_chat_ui(
                 chat.saved_input,
                 param_state,
                 sess,
+                *outputs,
             ],
         )
         with gr.Accordion(label="End chat", open=False):
@@ -162,6 +168,9 @@ def build_chat_ui(config, pipeline_def):
     exported_dir = output_dir.parent / "exported"
     exported_dir.mkdir(parents=True, exist_ok=True)
 
+    resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
+    allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
+
     def new_chat(*args):
         """Start a new chat function
 
@@ -190,7 +199,14 @@ def build_chat_ui(config, pipeline_def):
         )
 
         gr.Info("New chat session started.")
-        return [], [], None, param_state_str, session
+        return (
+            [],
+            [],
+            None,
+            param_state_str,
+            session,
+            *[None] * len(config.get("outputs", [])),
+        )
 
     def chat(message, history, session, *args):
         """The chat interface
@@ -212,7 +228,18 @@ def build_chat_ui(config, pipeline_def):
                 "No active chat session. Please set the params and click New chat"
             )
 
-        return session(message).content
+        pred = session(message)
+        text_response = pred.content
+
+        additional_outputs = []
+        for output_def in config.get("outputs", []):
+            value = session.last_run.logs(output_def["step"])
+            getter = output_def.get("getter", None)
+            if getter and getter in allowed_resultlog_callbacks:
+                value = getattr(resultlog, getter)(value)
+            additional_outputs.append(value)
+
+        return text_response, *additional_outputs
 
     def end_chat(preference: str, save_log: bool, session):
         """End the chat session
@@ -1,9 +1,10 @@
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional, TypeVar
 
-from haystack.schema import Document as HaystackDocument
 from llama_index.bridge.pydantic import Field
 from llama_index.schema import Document as BaseDocument
-from pyparsing import TypeVar
+
+if TYPE_CHECKING:
+    from haystack.schema import Document as HaystackDocument
 
 IO_Type = TypeVar("IO_Type", "Document", str)
 SAMPLE_TEXT = "A sample Document from kotaemon"
@@ -49,6 +50,8 @@ class Document(BaseDocument):
 
     def to_haystack_format(self) -> "HaystackDocument":
         """Convert struct to Haystack document format."""
+        from haystack.schema import Document as HaystackDocument
+
         metadata = self.metadata or {}
         text = self.text
         return HaystackDocument(content=text, meta=metadata)
@@ -56,11 +56,11 @@ class LangchainEmbeddings(BaseEmbeddings):
 
     def __setattr__(self, name, value):
         if name in self._lc_class.__fields__:
-            setattr(self.agent, name, value)
+            self._kwargs[name] = value
         else:
             super().__setattr__(name, value)
 
-    @Param.decorate(no_cache=True)
+    @Param.auto(cache=False)
     def agent(self):
         return self._lc_class(**self._kwargs)
 
@@ -6,7 +6,7 @@ from kotaemon.documents.base import Document
 
 
 class LLMInterface(Document):
-    candidates: List[str]
+    candidates: List[str] = Field(default_factory=list)
     completion_tokens: int = -1
     total_tokens: int = -1
     prompt_tokens: int = -1
@@ -40,7 +40,7 @@ class LangchainChatLLM(ChatLLM):
             self._kwargs[param] = params.pop(param)
         super().__init__(**params)
 
-    @Param.decorate(no_cache=True)
+    @Param.auto(cache=False)
     def agent(self) -> BaseLanguageModel:
         return self._lc_class(**self._kwargs)
 
@@ -92,3 +92,9 @@ class LangchainChatLLM(ChatLLM):
             setattr(self.agent, name, value)
         else:
             super().__setattr__(name, value)
+
+    def __getattr__(self, name):
+        if name in self._lc_class.__fields__:
+            getattr(self.agent, name)
+        else:
+            super().__getattr__(name)
@@ -27,7 +27,7 @@ class LangchainLLM(LLM):
             self._kwargs[param] = params.pop(param)
         super().__init__(**params)
 
-    @Param.decorate(no_cache=True)
+    @Param.auto(cache=False)
     def agent(self):
         return self._lc_class(**self._kwargs)
 
@@ -69,8 +69,8 @@ class Thought(BaseComponent):
         "variable placeholders, that then will be subsituted with real values when "
         "this component is executed"
     )
-    llm = Node(
-        default=AzureChatOpenAI, help="The LLM model to execute the input prompt"
+    llm: Node[BaseComponent] = Node(
+        AzureChatOpenAI, help="The LLM model to execute the input prompt"
     )
     post_process: Node[Compose] = Node(
         help="The function post-processor that post-processes LLM output prediction ."
@@ -78,7 +78,7 @@ class Thought(BaseComponent):
         "a dictionary, where the key should"
     )
 
-    @Node.decorate(depends_on="prompt")
+    @Node.auto(depends_on="prompt")
     def prompt_template(self):
         """Automatically wrap around param prompt. Can ignore"""
         return BasePromptComponent(self.prompt)
@@ -1,8 +1,10 @@
 import os
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Union
 
-from theflow import Node, Param
+from llama_index.readers.base import BaseReader
+from theflow import Node
+from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
 from kotaemon.docstores import InMemoryDocumentStore
@@ -32,33 +34,22 @@ class ReaderIndexingPipeline(BaseComponent):
     # Expose variables for users to switch in prompt ui
     storage_path: Path = Path("./storage")
     reader_name: str = "normal"  # "normal" or "mathpix"
-    openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
-    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
     chunk_size: int = 1024
     chunk_overlap: int = 256
     file_name_list: List[str] = list()
-
-    @Param.decorate()
-    def vector_store(self):
-        return InMemoryVectorStore()
-
-    @Param.decorate()
-    def doc_store(self):
-        doc_store = InMemoryDocumentStore()
-        return doc_store
-
-    @Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
-    def embedding(self):
-        return AzureOpenAIEmbeddings(
-            model="text-embedding-ada-002",
-            deployment="dummy-q2-text-embedding",
-            openai_api_base=self.openai_api_base,
-            openai_api_key=self.openai_api_key,
-        )
+    vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
+    doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
+
+    embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
+        model="text-embedding-ada-002",
+        deployment="dummy-q2-text-embedding",
+        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    )
 
     def get_reader(self, input_files: List[Union[str, Path]]):
         # document parsers
-        file_extractor = {
+        file_extractor: Dict[str, BaseReader] = {
             ".xlsx": PandasExcelReader(),
         }
         if self.reader_name == "normal":
@@ -71,7 +62,7 @@ class ReaderIndexingPipeline(BaseComponent):
         )
         return main_reader
 
-    @Node.decorate(depends_on=["doc_store", "vector_store", "embedding"])
+    @Node.auto(depends_on=["doc_store", "vector_store", "embedding"])
     def indexing_vector_pipeline(self):
         return IndexVectorStoreFromDocumentPipeline(
             doc_store=self.doc_store,
@@ -79,12 +70,9 @@ class ReaderIndexingPipeline(BaseComponent):
             embedding=self.embedding,
         )
 
-    @Node.decorate(depends_on=["chunk_size", "chunk_overlap"])
-    def text_splitter(self):
-        # chunking using NodeParser from llama-index
-        return SimpleNodeParser(
-            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
-        )
+    text_splitter: SimpleNodeParser = SimpleNodeParser.withx(
+        chunk_size=1024, chunk_overlap=256
+    )
 
     def run(
         self,
@@ -3,6 +3,7 @@ from pathlib import Path
 from typing import List
 
 from theflow import Node, Param
+from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
 from kotaemon.docstores import InMemoryDocumentStore
@@ -25,8 +26,6 @@ class QuestionAnsweringPipeline(BaseComponent):
 
     storage_path: Path = Path("./storage")
     retrieval_top_k: int = 3
-    openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
-    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
     file_name_list: List[str]
     """List of filename, incombination with storage_path to
     create persistent path of vectorstore"""
@@ -35,37 +34,27 @@ class QuestionAnsweringPipeline(BaseComponent):
         "The context is: \n{context}\nAnswer: "
     )
 
-    @Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
-    def llm(self):
-        return AzureChatOpenAI(
-            openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-            openai_api_key=self.openai_api_key,
-            openai_api_version="2023-03-15-preview",
-            deployment_name="dummy-q2-gpt35",
-            temperature=0,
-            request_timeout=60,
-        )
+    llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        openai_api_version="2023-03-15-preview",
+        deployment_name="dummy-q2-gpt35",
+        temperature=0,
+        request_timeout=60,
+    )
 
-    @Param.decorate()
-    def vector_store(self):
-        return InMemoryVectorStore()
+    vector_store: Param[InMemoryVectorStore] = Param(_(InMemoryVectorStore))
+    doc_store: Param[InMemoryDocumentStore] = Param(_(InMemoryDocumentStore))
 
-    @Param.decorate()
-    def doc_store(self):
-        doc_store = InMemoryDocumentStore()
-        return doc_store
-
-    @Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
-    def embedding(self):
-        return AzureOpenAIEmbeddings(
-            model="text-embedding-ada-002",
-            deployment="dummy-q2-text-embedding",
-            openai_api_base=self.openai_api_base,
-            openai_api_key=self.openai_api_key,
-        )
+    embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
+        model="text-embedding-ada-002",
+        deployment="dummy-q2-text-embedding",
+        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    )
 
-    @Node.decorate(depends_on=["doc_store", "vector_store", "embedding"])
-    def retrieving_pipeline(self):
+    @Node.default()
+    def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
         retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
             vector_store=self.vector_store,
             doc_store=self.doc_store,
@@ -32,5 +32,5 @@ class LLMTool(BaseTool):
             response = self.llm(query)
         except ValueError:
             raise ToolException("LLM Tool call failed")
-        output = response.text[0]
+        output = response.text
         return output
setup.py (2 changes)
@@ -30,7 +30,6 @@ setuptools.setup(
         exclude=("tests", "tests.*", "examples", "examples.*")
     ),
     install_requires=[
-        "farm-haystack==1.19.0",
         "langchain",
         "theflow",
         "llama-index",
@@ -59,6 +58,7 @@ setuptools.setup(
             "python-dotenv",
             "pytest-mock",
             "unstructured[pdf]",
+            "farm-haystack==1.19.0",
         ],
     },
     entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},
@@ -1,7 +1,8 @@
 import os
 from typing import List
 
-from theflow import Node, Param
+from theflow import Param
+from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
 from kotaemon.docstores import InMemoryDocumentStore
@@ -13,35 +14,28 @@ from kotaemon.vectorstores import ChromaVectorStore
 
 
 class QuestionAnsweringPipeline(BaseComponent):
-    vectorstore_path: str = str("./tmp")
     retrieval_top_k: int = 1
-    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
 
-    @Node.decorate(depends_on="openai_api_key")
-    def llm(self):
-        return AzureOpenAI(
-            openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-            openai_api_key=self.openai_api_key,
-            openai_api_version="2023-03-15-preview",
-            deployment_name="dummy-q2-gpt35",
-            temperature=0,
-            request_timeout=60,
-        )
+    llm: AzureOpenAI = AzureOpenAI.withx(
+        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        openai_api_version="2023-03-15-preview",
+        deployment_name="dummy-q2-gpt35",
+        temperature=0,
+        request_timeout=60,
+    )
 
-    @Node.decorate(depends_on=["vectorstore_path", "openai_api_key"])
-    def retrieving_pipeline(self):
-        vector_store = ChromaVectorStore(self.vectorstore_path)
-        embedding = AzureOpenAIEmbeddings(
-            model="text-embedding-ada-002",
-            deployment="dummy-q2-text-embedding",
-            openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-            openai_api_key=self.openai_api_key,
-        )
-
-        return RetrieveDocumentFromVectorStorePipeline(
-            vector_store=vector_store,
-            embedding=embedding,
+    retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
+        RetrieveDocumentFromVectorStorePipeline.withx(
+            vector_store=_(ChromaVectorStore).withx(path="./tmp"),
+            embedding=AzureOpenAIEmbeddings.withx(
+                model="text-embedding-ada-002",
+                deployment="dummy-q2-text-embedding",
+                openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+                openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+            ),
+        )
+    )
 
     def run_raw(self, text: str) -> str:
         # reload the document store, in case it has been updated
@@ -60,36 +54,27 @@ class QuestionAnsweringPipeline(BaseComponent):
         prompt = f'Answer the following question: "{text}". The context is: \n{context}'
         self.log_progress(".prompt", prompt=prompt)
 
-        return self.llm(prompt).text[0]
+        return self.llm(prompt).text
 
 
 class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
     # Expose variables for users to switch in prompt ui
-    vectorstore_path: str = str("./tmp")
-    embedding_model: str = "text-embedding-ada-002"
-    deployment: str = "dummy-q2-text-embedding"
-    openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
-    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
+    vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
 
-    @Param.decorate(depends_on=["vectorstore_path"])
-    def vector_store(self):
-        return ChromaVectorStore(self.vectorstore_path)
-
-    @Param.decorate()
-    def doc_store(self):
+    @Param.auto()
+    def doc_store(self) -> InMemoryDocumentStore:
         doc_store = InMemoryDocumentStore()
         if os.path.isfile("docstore.json"):
             doc_store.load("docstore.json")
         return doc_store
 
-    @Node.decorate(depends_on=["vector_store"])
-    def embedding(self):
-        return AzureOpenAIEmbeddings(
-            model="text-embedding-ada-002",
-            deployment=self.deployment,
-            openai_api_base=self.openai_api_base,
-            openai_api_key=self.openai_api_key,
-        )
+    embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
+        model="text-embedding-ada-002",
+        deployment="dummy-q2-text-embedding",
+        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    )
 
     def run_raw(self, text: str) -> int:  # type: ignore
         """Normally, this indexing pipeline returns nothing. For demonstration,
@@ -100,7 +85,7 @@ class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
 
         if self.doc_store is not None:
             # persist to local anytime an indexing is created
-            # this can be bypassed when we have a FileDocucmentStore
+            # this can be bypassed when we have a FileDocumentStore
             self.doc_store.save("docstore.json")
 
         return self.vector_store._collection.count()
@@ -13,3 +13,17 @@ def mock_google_search(monkeypatch):
     )
 
     monkeypatch.setattr(googlesearch, "search", result)
+
+
+def if_haystack_not_installed():
+    try:
+        import haystack  # noqa: F401
+    except ImportError:
+        return True
+    else:
+        return False
+
+
+skip_when_haystack_not_installed = pytest.mark.skipif(
+    if_haystack_not_installed(), reason="Haystack is not installed"
+)
@@ -1,7 +1,7 @@
 import tempfile
 from typing import List
 
-from theflow import Node
+from theflow.utils.modules import ObjectInitDeclaration as _
 
 from kotaemon.base import BaseComponent
 from kotaemon.embeddings import AzureOpenAIEmbeddings
@@ -11,33 +11,27 @@ from kotaemon.vectorstores import ChromaVectorStore
 
 
 class Pipeline(BaseComponent):
-    vectorstore_path: str = str(tempfile.mkdtemp())
-    llm: Node[AzureOpenAI] = Node(
-        default=AzureOpenAI,
-        default_kwargs={
-            "openai_api_base": "https://test.openai.azure.com/",
-            "openai_api_key": "some-key",
-            "openai_api_version": "2023-03-15-preview",
-            "deployment_name": "gpt35turbo",
-            "temperature": 0,
-            "request_timeout": 60,
-        },
+    llm: AzureOpenAI = AzureOpenAI.withx(
+        openai_api_base="https://test.openai.azure.com/",
+        openai_api_key="some-key",
+        openai_api_version="2023-03-15-preview",
+        deployment_name="gpt35turbo",
+        temperature=0,
+        request_timeout=60,
     )
 
-    @Node.decorate(depends_on=["vectorstore_path"])
-    def retrieving_pipeline(self):
-        vector_store = ChromaVectorStore(self.vectorstore_path)
-        embedding = AzureOpenAIEmbeddings(
-            model="text-embedding-ada-002",
-            deployment="embedding-deployment",
-            openai_api_base="https://test.openai.azure.com/",
-            openai_api_key="some-key",
-        )
-
-        return RetrieveDocumentFromVectorStorePipeline(
-            vector_store=vector_store, embedding=embedding
+    retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
+        RetrieveDocumentFromVectorStorePipeline.withx(
+            vector_store=_(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
+            embedding=AzureOpenAIEmbeddings.withx(
+                model="text-embedding-ada-002",
+                deployment="embedding-deployment",
+                openai_api_base="https://test.openai.azure.com/",
+                openai_api_key="some-key",
+            ),
+        )
+    )
 
     def run_raw(self, text: str) -> str:
         matched_texts: List[str] = self.retrieving_pipeline(text)
-        return self.llm("\n".join(matched_texts)).text[0]
+        return self.llm("\n".join(matched_texts)).text
@@ -1,7 +1,7 @@
-from haystack.schema import Document as HaystackDocument
-
 from kotaemon.documents.base import Document, RetrievedDocument
 
+from .conftest import skip_when_haystack_not_installed
+
 
 def test_document_constructor_with_builtin_types():
     for value in ["str", 1, {}, set(), [], tuple, None]:
@@ -19,7 +19,10 @@ def test_document_constructor_with_document():
     assert doc2.content == doc1.content
 
 
+@skip_when_haystack_not_installed
 def test_document_to_haystack_format():
+    from haystack.schema import Document as HaystackDocument
+
     text = "Sample text"
     metadata = {"filename": "sample.txt"}
     doc = Document(text, metadata=metadata)
@@ -16,7 +16,6 @@ class TestPromptConfig:
         assert "text" in config["inputs"], "inputs should have config"
 
         assert "params" in config, "params should be in config"
-        assert "vectorstore_path" in config["params"]
         assert "llm.deployment_name" in config["params"]
         assert "llm.openai_api_base" in config["params"]
         assert "llm.openai_api_key" in config["params"]
@@ -42,8 +42,9 @@ def mock_openai_embedding(monkeypatch):
 )
 def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
     indexing_pipeline = ReaderIndexingPipeline(
-        storage=tmp_path, openai_api_key="some-key"
+        storage_path=tmp_path,
     )
+    indexing_pipeline.embedding.openai_api_key = "some-key"
     input_file_path = Path(__file__).parent / "resources/dummy.pdf"
 
     # call ingestion pipeline
@@ -3,7 +3,7 @@ from pathlib import Path
 from langchain.schema import Document as LangchainDocument
 from llama_index.node_parser import SimpleNodeParser
 
-from kotaemon.documents.base import Document, HaystackDocument
+from kotaemon.documents.base import Document
 from kotaemon.loaders import AutoReader
 
 
@@ -19,10 +19,6 @@ def test_pdf_reader():
     assert isinstance(first_doc, Document)
     assert first_doc.text.lower().replace(" ", "") == "dummypdffile"
 
-    # check conversion output
-    haystack_doc = first_doc.to_haystack_format()
-    assert isinstance(haystack_doc, HaystackDocument)
-
     langchain_doc = first_doc.to_langchain_format()
     assert isinstance(langchain_doc, LangchainDocument)
 
@@ -3,6 +3,8 @@ import sys
 
 import pytest
 
+from .conftest import skip_when_haystack_not_installed
+
 
 @pytest.fixture
 def clean_artifacts_for_telemetry():
@@ -26,6 +28,7 @@ def clean_artifacts_for_telemetry():
 
 
 @pytest.mark.usefixtures("clean_artifacts_for_telemetry")
+@skip_when_haystack_not_installed
 def test_disable_telemetry_import_haystack_first():
     """Test that telemetry is disabled when kotaemon lib is initiated after"""
     import os
@@ -42,6 +45,7 @@ def test_disable_telemetry_import_haystack_first():
 
 
 @pytest.mark.usefixtures("clean_artifacts_for_telemetry")
+@skip_when_haystack_not_installed
 def test_disable_telemetry_import_haystack_after_kotaemon():
     """Test that telemetry is disabled when kotaemon lib is initiated before"""
     import os