[AUR-387, AUR-425] Add start-project to CLI (#29)
@@ -0,0 +1,106 @@
import os
from typing import List

from theflow import Node, Param

from kotaemon.base import BaseComponent
from kotaemon.docstores import InMemoryDocumentStore
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.llms.completions.openai import AzureOpenAI
from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
from kotaemon.vectorstores import ChromaVectorStore

class QuestionAnsweringPipeline(BaseComponent):
    vectorstore_path: str = "./tmp"
    retrieval_top_k: int = 1
    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")

    @Node.decorate(depends_on="openai_api_key")
    def llm(self):
        return AzureOpenAI(
            openai_api_base="https://bleh-dummy-2.openai.azure.com/",
            openai_api_key=self.openai_api_key,
            openai_api_version="2023-03-15-preview",
            deployment_name="dummy-q2-gpt35",
            temperature=0,
            request_timeout=60,
        )

    @Node.decorate(depends_on=["vectorstore_path", "openai_api_key"])
    def retrieving_pipeline(self):
        vector_store = ChromaVectorStore(self.vectorstore_path)
        embedding = AzureOpenAIEmbeddings(
            model="text-embedding-ada-002",
            deployment="dummy-q2-text-embedding",
            openai_api_base="https://bleh-dummy-2.openai.azure.com/",
            openai_api_key=self.openai_api_key,
        )

        return RetrieveDocumentFromVectorStorePipeline(
            vector_store=vector_store,
            embedding=embedding,
        )

    def run_raw(self, text: str) -> str:
        # reload the document store, in case it has been updated
        doc_store = InMemoryDocumentStore()
        doc_store.load("docstore.json")
        self.retrieving_pipeline.doc_store = doc_store

        # retrieve relevant documents as context
        matched_texts: List[str] = [
            _.text
            for _ in self.retrieving_pipeline(text, top_k=int(self.retrieval_top_k))
        ]
        context = "\n".join(matched_texts)

        # generate the answer
        prompt = f'Answer the following question: "{text}". The context is: \n{context}'
        self.log_progress(".prompt", prompt=prompt)

        return self.llm(prompt).text[0]

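# Usage sketch (illustrative, not part of this commit): assuming
# OPENAI_API_KEY is set and "./tmp" holds a populated Chroma collection
# alongside a saved docstore.json, the pipeline is called directly:
#
#     qa = QuestionAnsweringPipeline(retrieval_top_k=3)
#     answer = qa.run_raw("What is kotaemon?")
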
class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
    # Expose variables for users to switch in the prompt UI
    vectorstore_path: str = "./tmp"
    embedding_model: str = "text-embedding-ada-002"
    deployment: str = "dummy-q2-text-embedding"
    openai_api_base: str = "https://bleh-dummy-2.openai.azure.com/"
    openai_api_key: str = os.environ.get("OPENAI_API_KEY", "")

    @Param.decorate(depends_on=["vectorstore_path"])
    def vector_store(self):
        return ChromaVectorStore(self.vectorstore_path)

    @Param.decorate()
    def doc_store(self):
        doc_store = InMemoryDocumentStore()
        if os.path.isfile("docstore.json"):
            doc_store.load("docstore.json")
        return doc_store

    @Node.decorate(depends_on=["vector_store"])
    def embedding(self):
        return AzureOpenAIEmbeddings(
            model=self.embedding_model,  # use the exposed param rather than a hard-coded value
            deployment=self.deployment,
            openai_api_base=self.openai_api_base,
            openai_api_key=self.openai_api_key,
        )

    def run_raw(self, text: str) -> int:  # type: ignore
        """Normally, this indexing pipeline returns nothing. For demonstration,
        we want it to return something, so let's return the number of documents
        in the vector store.
        """
        super().run_raw(text)

        if self.doc_store is not None:
            # persist to disk any time indexing runs;
            # this can be bypassed once we have a FileDocumentStore
            self.doc_store.save("docstore.json")

        return self.vector_store._collection.count()
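
# Illustrative driver (an assumption, not part of the original commit): it
# shows how the two pipelines are expected to interact: index a document,
# persist the docstore, then answer a question against it. It assumes
# run_raw accepts raw text, as the overrides above suggest.
if __name__ == "__main__":
    indexing = IndexingPipeline()
    count = indexing.run_raw("kotaemon is a toolkit for building RAG pipelines.")
    print(f"documents in vector store: {count}")

    qa = QuestionAnsweringPipeline()
    print(qa.run_raw("What is kotaemon?"))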