Upgrade the declarative pipeline for cleaner interface (#51)

This commit is contained in:
Nguyen Trung Duc (john)
2023-10-24 11:12:22 +07:00
committed by GitHub
parent aab982ddc4
commit 9035e25666
26 changed files with 365 additions and 169 deletions

View File

@@ -1,7 +1,8 @@
import os
from typing import List
from theflow import Node, Param
from theflow import Param
from theflow.utils.modules import ObjectInitDeclaration as _
from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
@@ -13,35 +14,28 @@ from kotaemon.vectorstores import ChromaVectorStore
class QuestionAnsweringPipeline(BaseComponent):
vectorstore_path: str = str("./tmp")
retrieval_top_k: int = 1
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
@Node.decorate(depends_on="openai_api_key")
def llm(self):
return AzureOpenAI(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=self.openai_api_key,
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=60,
)
llm: AzureOpenAI = AzureOpenAI.withx(
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-03-15-preview",
deployment_name="dummy-q2-gpt35",
temperature=0,
request_timeout=60,
)
@Node.decorate(depends_on=["vectorstore_path", "openai_api_key"])
def retrieving_pipeline(self):
vector_store = ChromaVectorStore(self.vectorstore_path)
embedding = AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=self.openai_api_key,
)
return RetrieveDocumentFromVectorStorePipeline(
vector_store=vector_store,
embedding=embedding,
retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
RetrieveDocumentFromVectorStorePipeline.withx(
vector_store=_(ChromaVectorStore).withx(path="./tmp"),
embedding=AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
),
)
)
def run_raw(self, text: str) -> str:
# reload the document store, in case it has been updated
@@ -60,36 +54,27 @@ class QuestionAnsweringPipeline(BaseComponent):
prompt = f'Answer the following question: "{text}". The context is: \n{context}'
self.log_progress(".prompt", prompt=prompt)
return self.llm(prompt).text[0]
return self.llm(prompt).text
class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
# Expose variables for users to switch in prompt ui
vectorstore_path: str = str("./tmp")
embedding_model: str = "text-embedding-ada-002"
deployment: str = "dummy-q2-text-embedding"
openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")
vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
@Param.decorate(depends_on=["vectorstore_path"])
def vector_store(self):
return ChromaVectorStore(self.vectorstore_path)
@Param.decorate()
def doc_store(self):
@Param.auto()
def doc_store(self) -> InMemoryDocumentStore:
doc_store = InMemoryDocumentStore()
if os.path.isfile("docstore.json"):
doc_store.load("docstore.json")
return doc_store
@Node.decorate(depends_on=["vector_store"])
def embedding(self):
return AzureOpenAIEmbeddings(
model="text-embedding-ada-002",
deployment=self.deployment,
openai_api_base=self.openai_api_base,
openai_api_key=self.openai_api_key,
)
embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
model="text-embedding-ada-002",
deployment="dummy-q2-text-embedding",
openai_api_base="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
)
def run_raw(self, text: str) -> int: # type: ignore
"""Normally, this indexing pipeline returns nothing. For demonstration,
@@ -100,7 +85,7 @@ class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
if self.doc_store is not None:
# persist to local anytime an indexing is created
# this can be bypassed when we have a FileDocucmentStore
# this can be bypassed when we have a FileDocumentStore
self.doc_store.save("docstore.json")
return self.vector_store._collection.count()