Upgrade the declarative pipeline for cleaner interface (#51)

Nguyen Trung Duc (john) 2023-10-24 11:12:22 +07:00 committed by GitHub
parent aab982ddc4
commit 9035e25666
26 changed files with 365 additions and 169 deletions

View File

@ -22,4 +22,4 @@ try:
except ImportError:
pass
__version__ = "0.0.4"
__version__ = "0.2.0"

View File

@ -10,4 +10,4 @@ class SimpleRespondentChatbot(BaseChatBot):
llm: Node[ChatLLM]
def _get_message(self) -> str:
return self.llm(self.history).text[0]
return self.llm(self.history).text

View File

@ -63,19 +63,19 @@ def handle_node(node: dict) -> dict:
"""Convert node definition into promptui-compliant config"""
config = {}
for name, param_def in node.get("params", {}).items():
if isinstance(param_def["default_callback"], str):
if isinstance(param_def["auto_callback"], str):
continue
if param_def.get("ignore_ui", False):
continue
config[name] = handle_param(param_def)
for name, node_def in node.get("nodes", {}).items():
if isinstance(node_def["default_callback"], str):
if isinstance(node_def["auto_callback"], str):
continue
if node_def.get("ignore_ui", False):
continue
for key, value in handle_node(node_def["default"]).items():
config[f"{name}.{key}"] = value
for key, value in node_def["default_kwargs"].items():
for key, value in node_def.get("default_kwargs", {}).items():
config[f"{name}.{key}"] = config_from_value(value)
return config
@ -124,11 +124,14 @@ def export_pipeline_to_config(
if ui_type == "chat":
params = {f".bot.{k}": v for k, v in handle_node(pipeline_def).items()}
params["system_message"] = {"component": "text", "params": {"value": ""}}
outputs = []
if hasattr(pipeline, "_promptui_outputs"):
outputs = pipeline._promptui_outputs
config_obj: dict = {
"ui-type": ui_type,
"params": params,
"inputs": {},
"outputs": [],
"outputs": outputs,
"logs": {
"full_pipeline": {
"input": {

View File

@ -61,6 +61,9 @@ def from_log_to_dict(pipeline_cls: Type[BaseComponent], log_config: dict) -> dic
if name not in logged_infos:
logged_infos[name] = [None] * len(dirs)
if step not in progress:
continue
info = progress[step]
if getter:
if getter in allowed_resultlog_callbacks:

View File

@ -13,9 +13,9 @@ class John(Base):
primary_hue: colors.Color | str = colors.neutral,
secondary_hue: colors.Color | str = colors.neutral,
neutral_hue: colors.Color | str = colors.neutral,
spacing_size: sizes.Size | str = sizes.spacing_lg,
spacing_size: sizes.Size | str = sizes.spacing_sm,
radius_size: sizes.Size | str = sizes.radius_none,
text_size: sizes.Size | str = sizes.text_md,
text_size: sizes.Size | str = sizes.text_sm,
font: fonts.Font
| str
| Iterable[fonts.Font | str] = (
@ -79,8 +79,8 @@ class John(Base):
button_cancel_background_fill_hover="*button_primary_background_fill_hover",
button_cancel_text_color="*button_primary_text_color",
# Padding
checkbox_label_padding="*spacing_md",
button_large_padding="*spacing_lg",
checkbox_label_padding="*spacing_sm",
button_large_padding="*spacing_sm",
button_small_padding="*spacing_sm",
# Borders
block_border_width="0px",
@ -91,5 +91,5 @@ class John(Base):
# Block Labels
block_title_text_weight="600",
block_label_text_weight="600",
block_label_text_size="*text_md",
block_label_text_size="*text_sm",
)

View File

@ -26,9 +26,9 @@ def build_from_dict(config: Union[str, dict]):
for key, value in config_dict.items():
pipeline_def = import_dotted_string(key, safe=False)
if value["ui-type"] == "chat":
demos.append(build_chat_ui(value, pipeline_def))
demos.append(build_chat_ui(value, pipeline_def).queue())
else:
demos.append(build_pipeline_ui(value, pipeline_def))
demos.append(build_pipeline_ui(value, pipeline_def).queue())
if len(demos) == 1:
demo = demos[0]
else:

View File

@ -0,0 +1,181 @@
from __future__ import annotations
from typing import Any, AsyncGenerator
import anyio
from gradio import ChatInterface
from gradio.components import IOComponent, get_component_instance
from gradio.events import on
from gradio.helpers import special_args
from gradio.routes import Request
class ChatBlock(ChatInterface):
"""The ChatBlock subclasses ChatInterface to provide extra functionalities:
- Show additional outputs to the chat interface
- Disallow blank user message
"""
def __init__(
self,
*args,
additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
**kwargs,
):
if additional_outputs:
if not isinstance(additional_outputs, list):
additional_outputs = [additional_outputs]
self.additional_outputs = [
get_component_instance(i) for i in additional_outputs # type: ignore
]
else:
self.additional_outputs = []
super().__init__(*args, **kwargs)
async def _submit_fn(
self,
message: str,
history_with_input: list[list[str | None]],
request: Request,
*args,
) -> tuple[Any, ...]:
input_args = args[: -len(self.additional_outputs)]
output_args = args[-len(self.additional_outputs) :]
if not message:
return history_with_input, history_with_input, *output_args
history = history_with_input[:-1]
inputs, _, _ = special_args(
self.fn, inputs=[message, history, *input_args], request=request
)
if self.is_async:
response = await self.fn(*inputs)
else:
response = await anyio.to_thread.run_sync(
self.fn, *inputs, limiter=self.limiter
)
output = []
if self.additional_outputs:
text = response[0]
output = response[1:]
else:
text = response
history.append([message, text])
return history, history, *output
async def _stream_fn(
self,
message: str,
history_with_input: list[list[str | None]],
*args,
) -> AsyncGenerator:
raise NotImplementedError("Stream function not implemented for ChatBlock")
def _display_input(
self, message: str, history: list[list[str | None]]
) -> tuple[list[list[str | None]], list[list[str | None]]]:
"""Stop displaying the input message if the message is a blank string"""
if not message:
return history, history
return super()._display_input(message, history)
def _setup_events(self) -> None:
"""Include additional outputs in the submit event"""
submit_fn = self._stream_fn if self.is_generator else self._submit_fn
submit_triggers = (
[self.textbox.submit, self.submit_btn.click]
if self.submit_btn
else [self.textbox.submit]
)
submit_event = (
on(
submit_triggers,
self._clear_and_save_textbox,
[self.textbox],
[self.textbox, self.saved_input],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events(submit_triggers, submit_event)
if self.retry_btn:
retry_event = (
self.retry_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
self._display_input,
[self.saved_input, self.chatbot_state],
[self.chatbot, self.chatbot_state],
api_name=False,
queue=False,
)
.then(
submit_fn,
[self.saved_input, self.chatbot_state]
+ self.additional_inputs
+ self.additional_outputs,
[self.chatbot, self.chatbot_state] + self.additional_outputs,
api_name=False,
)
)
self._setup_stop_events([self.retry_btn.click], retry_event)
if self.undo_btn:
self.undo_btn.click(
self._delete_prev_fn,
[self.chatbot_state],
[self.chatbot, self.saved_input, self.chatbot_state],
api_name=False,
queue=False,
).then(
lambda x: x,
[self.saved_input],
[self.textbox],
api_name=False,
queue=False,
)
if self.clear_btn:
self.clear_btn.click(
lambda: ([], [], None),
None,
[self.chatbot, self.chatbot_state, self.saved_input],
queue=False,
api_name=False,
)
def _setup_api(self) -> None:
api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn
self.fake_api_btn.click(
api_fn,
[self.textbox, self.chatbot_state] + self.additional_inputs,
[self.textbox, self.chatbot_state] + self.additional_outputs,
api_name="chat",
)
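
A hedged usage sketch of the new class (the chat function and the extra output component below are illustrative placeholders, not part of this commit): ChatBlock is constructed like gr.ChatInterface, and additional_outputs lists the components that receive the trailing return values of the chat function.

```python
# Hedged usage sketch; `answer_fn` and the `evidence` component are
# illustrative placeholders, not part of this commit.
import gradio as gr

def answer_fn(message, history, session):
    # Stand-in for a real pipeline call. With `additional_outputs` set, the
    # reply text comes first, followed by one value per extra output component.
    return f"Echo: {message}", "retrieved context would go here"

evidence = gr.Textbox(label="Evidence", interactive=False)
sess = gr.State(value=None)
chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)

chat = ChatBlock(
    answer_fn,
    chatbot=chatbot,
    additional_inputs=[sess],
    additional_outputs=[evidence],
)
```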

View File

@ -8,6 +8,9 @@ from theflow.storage import storage
from kotaemon.chatbot import ChatConversation
from kotaemon.contribs.promptui.base import get_component
from kotaemon.contribs.promptui.export import export
from kotaemon.contribs.promptui.ui.blocks import ChatBlock
from ..logs import ResultLog
USAGE_INSTRUCTION = """## How to use:
@ -87,8 +90,10 @@ def construct_chat_ui(
outputs.append(component)
sess = gr.State(value=None)
chatbot = gr.Chatbot(label="Chatbot")
chat = gr.ChatInterface(func_chat, chatbot=chatbot, additional_inputs=[sess])
chatbot = gr.Chatbot(label="Chatbot", show_copy_button=True)
chat = ChatBlock(
func_chat, chatbot=chatbot, additional_inputs=[sess], additional_outputs=outputs
)
param_state = gr.Textbox(interactive=False)
with gr.Blocks(analytics_enabled=False, title="Welcome to PromptUI") as demo:
@ -106,6 +111,7 @@ def construct_chat_ui(
chat.saved_input,
param_state,
sess,
*outputs,
],
)
with gr.Accordion(label="End chat", open=False):
@ -162,6 +168,9 @@ def build_chat_ui(config, pipeline_def):
exported_dir = output_dir.parent / "exported"
exported_dir.mkdir(parents=True, exist_ok=True)
resultlog = getattr(pipeline_def, "_promptui_resultlog", ResultLog)
allowed_resultlog_callbacks = {i for i in dir(resultlog) if not i.startswith("__")}
def new_chat(*args):
"""Start a new chat function
@ -190,7 +199,14 @@ def build_chat_ui(config, pipeline_def):
)
gr.Info("New chat session started.")
return [], [], None, param_state_str, session
return (
[],
[],
None,
param_state_str,
session,
*[None] * len(config.get("outputs", [])),
)
def chat(message, history, session, *args):
"""The chat interface
@ -212,7 +228,18 @@ def build_chat_ui(config, pipeline_def):
"No active chat session. Please set the params and click New chat"
)
return session(message).content
pred = session(message)
text_response = pred.content
additional_outputs = []
for output_def in config.get("outputs", []):
value = session.last_run.logs(output_def["step"])
getter = output_def.get("getter", None)
if getter and getter in allowed_resultlog_callbacks:
value = getattr(resultlog, getter)(value)
additional_outputs.append(value)
return text_response, *additional_outputs
def end_chat(preference: str, save_log: bool, session):
"""End the chat session

View File

@ -1,9 +1,10 @@
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional, TypeVar
from haystack.schema import Document as HaystackDocument
from llama_index.bridge.pydantic import Field
from llama_index.schema import Document as BaseDocument
from pyparsing import TypeVar
if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument
IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
@ -49,6 +50,8 @@ class Document(BaseDocument):
def to_haystack_format(self) -> "HaystackDocument":
"""Convert struct to Haystack document format."""
from haystack.schema import Document as HaystackDocument
metadata = self.metadata or {}
text = self.text
return HaystackDocument(content=text, meta=metadata)

View File

@ -56,11 +56,11 @@ class LangchainEmbeddings(BaseEmbeddings):
def __setattr__(self, name, value):
if name in self._lc_class.__fields__:
setattr(self.agent, name, value)
self._kwargs[name] = value
else:
super().__setattr__(name, value)
@Param.decorate(no_cache=True)
@Param.auto(cache=False)
def agent(self):
return self._lc_class(**self._kwargs)

View File

@ -6,7 +6,7 @@ from kotaemon.documents.base import Document
class LLMInterface(Document):
candidates: List[str]
candidates: List[str] = Field(default_factory=list)
completion_tokens: int = -1
total_tokens: int = -1
prompt_tokens: int = -1
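
Several call sites in this commit switch from response.text[0] to response.text, which suggests text now resolves to a single string. A hedged, standalone sketch of such an accessor (not shown in this diff):

```python
# Hedged standalone sketch, not part of this diff: a `text` property returning
# the first candidate as a plain string would explain the `.text[0]` -> `.text`
# call-site changes made elsewhere in this commit.
from dataclasses import dataclass, field
from typing import List

@dataclass
class LLMInterfaceSketch:
    candidates: List[str] = field(default_factory=list)

    @property
    def text(self) -> str:
        return self.candidates[0] if self.candidates else ""
```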

View File

@ -40,7 +40,7 @@ class LangchainChatLLM(ChatLLM):
self._kwargs[param] = params.pop(param)
super().__init__(**params)
@Param.decorate(no_cache=True)
@Param.auto(cache=False)
def agent(self) -> BaseLanguageModel:
return self._lc_class(**self._kwargs)
@ -92,3 +92,9 @@ class LangchainChatLLM(ChatLLM):
setattr(self.agent, name, value)
else:
super().__setattr__(name, value)
def __getattr__(self, name):
    if name in self._lc_class.__fields__:
        return getattr(self.agent, name)
    raise AttributeError(name)
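
With __setattr__ and __getattr__ both in place, fields declared on the wrapped Langchain class are proxied through the wrapper. A hedged usage sketch (the import path and parameter values are assumptions):

```python
# Hedged usage sketch; the import path and parameters are assumptions based on
# how AzureChatOpenAI is used elsewhere in this commit.
from kotaemon.llms.chats.openai import AzureChatOpenAI  # path assumed

llm = AzureChatOpenAI(temperature=0, request_timeout=60)
llm.temperature = 0.7   # __setattr__ forwards Langchain fields to the cached agent
print(llm.temperature)  # __getattr__ reads the same field back from the agent
```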

View File

@ -27,7 +27,7 @@ class LangchainLLM(LLM):
self._kwargs[param] = params.pop(param)
super().__init__(**params)
@Param.decorate(no_cache=True)
@Param.auto(cache=False)
def agent(self):
return self._lc_class(**self._kwargs)

View File

@ -69,8 +69,8 @@ class Thought(BaseComponent):
"variable placeholders, that then will be subsituted with real values when "
"this component is executed"
)
llm = Node(
default=AzureChatOpenAI, help="The LLM model to execute the input prompt"
llm: Node[BaseComponent] = Node(
AzureChatOpenAI, help="The LLM model to execute the input prompt"
)
post_process: Node[Compose] = Node(
help="The function post-processor that post-processes LLM output prediction ."
@ -78,7 +78,7 @@ class Thought(BaseComponent):
"a dictionary, where the key should"
)
@Node.decorate(depends_on="prompt")
@Node.auto(depends_on="prompt")
def prompt_template(self):
"""Automatically wrap around param prompt. Can ignore"""
return BasePromptComponent(self.prompt)

View File

@ -1,8 +1,10 @@
import os
from pathlib import Path
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union
from theflow import Node, Param
from llama_index.readers.base import BaseReader
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@ -32,33 +34,22 @@ class ReaderIndexingPipeline(BaseComponent):
# Expose variables for users to switch in prompt ui
storage_path: Path = Path("./storage")
reader_name: str = "normal" # "normal" or "mathpix"
openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
chunk_size: int = 1024
chunk_overlap: int = 256
file_name_list: List[str] = list()
vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
@Param.decorate()
def vector_store(self):
return InMemoryVectorStore()
@Param.decorate()
def doc_store(self):
doc_store = InMemoryDocumentStore()
return doc_store
@Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
def embedding(self):
return AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base=self.openai_api_base,
openai_api_key=self.openai_api_key,
)
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
)
def get_reader(self, input_files: List[Union[str, Path]]):
# document parsers
file_extractor = {
file_extractor: Dict[str, BaseReader] = {
".xlsx": PandasExcelReader(),
}
if self.reader_name == "normal":
@ -71,7 +62,7 @@ class ReaderIndexingPipeline(BaseComponent):
)
return main_reader
@Node.decorate(depends_on=["doc_store", "vector_store", "embedding"])
@Node.auto(depends_on=["doc_store", "vector_store", "embedding"])
def indexing_vector_pipeline(self):
return IndexVectorStoreFromDocumentPipeline(
doc_store=self.doc_store,
@ -79,12 +70,9 @@ class ReaderIndexingPipeline(BaseComponent):
embedding=self.embedding,
)
@Node.decorate(depends_on=["chunk_size", "chunk_overlap"])
def text_splitter(self):
# chunking using NodeParser from llama-index
return SimpleNodeParser(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
text_splitter: SimpleNodeParser = SimpleNodeParser.withx(
chunk_size=1024, chunk_overlap=256
)
def run(
self,
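
The pipeline above illustrates the declarative style this commit moves to: instead of @Param.decorate/@Node.decorate factory methods, nodes are declared inline with ClassName.withx(**kwargs) or theflow's ObjectInitDeclaration (_) and instantiated lazily when the pipeline is built. A minimal hedged sketch (component choices are illustrative; import paths follow the diffs in this commit where shown, and are assumed otherwise):

```python
# Minimal hedged sketch of the declarative node style used above; the component
# classes and kwargs are illustrative.
from pathlib import Path
from theflow.utils.modules import ObjectInitDeclaration as _

from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.vectorstores import InMemoryVectorStore  # path assumed

class MyIndexingPipeline(BaseComponent):
    storage_path: Path = Path("./storage")
    # nodes are declared with their constructor arguments and created on demand
    doc_store: _[InMemoryDocumentStore] = _(InMemoryDocumentStore)
    vector_store: _[InMemoryVectorStore] = _(InMemoryVectorStore)
    embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
        model="text-embedding-ada-002",
        deployment="dummy-q2-text-embedding",
    )
```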

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import List
from theflow import Node, Param
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@ -25,8 +26,6 @@ class QuestionAnsweringPipeline(BaseComponent):
storage_path: Path = Path("./storage")
retrieval_top_k: int = 3
openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
file_name_list: List[str]
"""List of filename, incombination with storage_path to
create persistent path of vectorstore"""
@ -35,37 +34,27 @@ class QuestionAnsweringPipeline(BaseComponent):
"The context is: \n{context}\nAnswer: "
)
@Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
def llm(self):
return AzureChatOpenAI(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=self.openai_api_key,
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=60,
)
llm: AzureChatOpenAI = AzureChatOpenAI.withx(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=60,
)
@Param.decorate()
def vector_store(self):
return InMemoryVectorStore()
vector_store: Param[InMemoryVectorStore] = Param(_(InMemoryVectorStore))
doc_store: Param[InMemoryDocumentStore] = Param(_(InMemoryDocumentStore))
@Param.decorate()
def doc_store(self):
doc_store = InMemoryDocumentStore()
return doc_store
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
)
@Node.decorate(depends_on=["openai_api_base", "openai_api_key"])
def embedding(self):
return AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base=self.openai_api_base,
openai_api_key=self.openai_api_key,
)
@Node.decorate(depends_on=["doc_store", "vector_store", "embedding"])
def retrieving_pipeline(self):
@Node.default()
def retrieving_pipeline(self) -> RetrieveDocumentFromVectorStorePipeline:
retrieving_pipeline = RetrieveDocumentFromVectorStorePipeline(
vector_store=self.vector_store,
doc_store=self.doc_store,

View File

@ -32,5 +32,5 @@ class LLMTool(BaseTool):
response = self.llm(query)
except ValueError:
raise ToolException("LLM Tool call failed")
output = response.text[0]
output = response.text
return output

View File

@ -30,7 +30,6 @@ setuptools.setup(
exclude=("tests", "tests.*", "examples", "examples.*")
),
install_requires=[
"farm-haystack==1.19.0",
"langchain",
"theflow",
"llama-index",
@ -59,6 +58,7 @@ setuptools.setup(
"python-dotenv",
"pytest-mock",
"unstructured[pdf]",
"farm-haystack==1.19.0",
],
},
entry_points={"console_scripts": ["kh=kotaemon.cli:main"]},

View File

@ -1,7 +1,8 @@
import os
from typing import List
from theflow import Node, Param
from theflow import Param
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@ -13,35 +14,28 @@ from kotaemon.vectorstores import ChromaVectorStore
class QuestionAnsweringPipeline(BaseComponent):
vectorstore_path: str = str("./tmp")
retrieval_top_k: int = 1
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
@Node.decorate(depends_on="openai_api_key")
def llm(self):
return AzureOpenAI(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=self.openai_api_key,
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=60,
)
llm: AzureOpenAI = AzureOpenAI.withx(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=60,
)
@Node.decorate(depends_on=["vectorstore_path", "openai_api_key"])
def retrieving_pipeline(self):
vector_store = ChromaVectorStore(self.vectorstore_path)
embedding = AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=self.openai_api_key,
)
return RetrieveDocumentFromVectorStorePipeline(
vector_store=vector_store,
embedding=embedding,
retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
RetrieveDocumentFromVectorStorePipeline.withx(
vector_store=_(ChromaVectorStore).withx(path="./tmp"),
embedding=AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
),
)
)
def run_raw(self, text: str) -> str:
# reload the document store, in case it has been updated
@ -60,36 +54,27 @@ class QuestionAnsweringPipeline(BaseComponent):
prompt = f'Answer the following question: "{text}". The context is: \n{context}'
self.log_progress(".prompt", prompt=prompt)
return self.llm(prompt).text[0]
return self.llm(prompt).text
class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
# Expose variables for users to switch in prompt ui
vectorstore_path: str = str("./tmp")
embedding_model: str = "text-embedding-ada-002"
deployment: str = "dummy-q2-text-embedding"
openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
@Param.decorate(depends_on=["vectorstore_path"])
def vector_store(self):
return ChromaVectorStore(self.vectorstore_path)
@Param.decorate()
def doc_store(self):
@Param.auto()
def doc_store(self) -> InMemoryDocumentStore:
doc_store = InMemoryDocumentStore()
if os.path.isfile("docstore.json"):
doc_store.load("docstore.json")
return doc_store
@Node.decorate(depends_on=["vector_store"])
def embedding(self):
return AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment=self.deployment,
openai_api_base=self.openai_api_base,
openai_api_key=self.openai_api_key,
)
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
)
def run_raw(self, text: str) -> int: # type: ignore
"""Normally, this indexing pipeline returns nothing. For demonstration,
@ -100,7 +85,7 @@ class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
if self.doc_store is not None:
# persist to local anytime an indexing is created
# this can be bypassed when we have a FileDocucmentStore
# this can be bypassed when we have a FileDocumentStore
self.doc_store.save("docstore.json")
return self.vector_store._collection.count()

View File

@ -13,3 +13,17 @@ def mock_google_search(monkeypatch):
)
monkeypatch.setattr(googlesearch, "search", result)
def if_haystack_not_installed():
try:
import haystack # noqa: F401
except ImportError:
return True
else:
return False
skip_when_haystack_not_installed = pytest.mark.skipif(
if_haystack_not_installed(), reason="Haystack is not installed"
)

View File

@ -1,7 +1,7 @@
import tempfile
from typing import List
from theflow import Node
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent
from kotaemon.embeddings import AzureOpenAIEmbeddings
@ -11,33 +11,27 @@ from kotaemon.vectorstores import ChromaVectorStore
class Pipeline(BaseComponent):
vectorstore_path: str = str(tempfile.mkdtemp())
llm: Node[AzureOpenAI] = Node(
default=AzureOpenAI,
default_kwargs={
"openai_api_base": "https://test.openai.azure.com/",
"openai_api_key": "some-key",
"openai_api_version": "2023-03-15-preview",
"deployment_name": "gpt35turbo",
"temperature": 0,
"request_timeout": 60,
},
llm: AzureOpenAI = AzureOpenAI.withx(
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",
deployment_name="gpt35turbo",
temperature=0,
request_timeout=60,
)
@Node.decorate(depends_on=["vectorstore_path"])
def retrieving_pipeline(self):
vector_store = ChromaVectorStore(self.vectorstore_path)
embedding = AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment="embedding-deployment",
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
)
return RetrieveDocumentFromVectorStorePipeline(
vector_store=vector_store, embedding=embedding
retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
RetrieveDocumentFromVectorStorePipeline.withx(
vector_store=_(ChromaVectorStore).withx(path=str(tempfile.mkdtemp())),
embedding=AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="embedding-deployment",
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
),
)
)
def run_raw(self, text: str) -> str:
matched_texts: List[str] = self.retrieving_pipeline(text)
return self.llm("\n".join(matched_texts)).text[0]
return self.llm("\n".join(matched_texts)).text

View File

@ -1,7 +1,7 @@
from haystack.schema import Document as HaystackDocument
from kotaemon.documents.base import Document, RetrievedDocument
from .conftest import skip_when_haystack_not_installed
def test_document_constructor_with_builtin_types():
for value in ["str", 1, {}, set(), [], tuple, None]:
@ -19,7 +19,10 @@ def test_document_constructor_with_document():
assert doc2.content == doc1.content
@skip_when_haystack_not_installed
def test_document_to_haystack_format():
from haystack.schema import Document as HaystackDocument
text = "Sample text"
metadata = {"filename": "sample.txt"}
doc = Document(text, metadata=metadata)

View File

@ -16,7 +16,6 @@ class TestPromptConfig:
assert "text" in config["inputs"], "inputs should have config"
assert "params" in config, "params should be in config"
assert "vectorstore_path" in config["params"]
assert "llm.deployment_name" in config["params"]
assert "llm.openai_api_base" in config["params"]
assert "llm.openai_api_key" in config["params"]

View File

@ -42,8 +42,9 @@ def mock_openai_embedding(monkeypatch):
)
def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
indexing_pipeline = ReaderIndexingPipeline(
storage=tmp_path, openai_api_key="some-key"
storage_path=tmp_path,
)
indexing_pipeline.embedding.openai_api_key = "some-key"
input_file_path = Path(__file__).parent / "resources/dummy.pdf"
# call ingestion pipeline

View File

@ -3,7 +3,7 @@ from pathlib import Path
from langchain.schema import Document as LangchainDocument
from llama_index.node_parser import SimpleNodeParser
from kotaemon.documents.base import Document, HaystackDocument
from kotaemon.documents.base import Document
from kotaemon.loaders import AutoReader
@ -19,10 +19,6 @@ def test_pdf_reader():
assert isinstance(first_doc, Document)
assert first_doc.text.lower().replace(" ", "") == "dummypdffile"
# check conversion output
haystack_doc = first_doc.to_haystack_format()
assert isinstance(haystack_doc, HaystackDocument)
langchain_doc = first_doc.to_langchain_format()
assert isinstance(langchain_doc, LangchainDocument)

View File

@ -3,6 +3,8 @@ import sys
import pytest
from .conftest import skip_when_haystack_not_installed
@pytest.fixture
def clean_artifacts_for_telemetry():
@ -26,6 +28,7 @@ def clean_artifacts_for_telemetry():
@pytest.mark.usefixtures("clean_artifacts_for_telemetry")
@skip_when_haystack_not_installed
def test_disable_telemetry_import_haystack_first():
"""Test that telemetry is disabled when kotaemon lib is initiated after"""
import os
@ -42,6 +45,7 @@ def test_disable_telemetry_import_haystack_first():
@pytest.mark.usefixtures("clean_artifacts_for_telemetry")
@skip_when_haystack_not_installed
def test_disable_telemetry_import_haystack_after_kotaemon():
"""Test that telemetry is disabled when kotaemon lib is initiated before"""
import os