Upgrade promptui to conform to Gradio V4 (#98)
committed by GitHub
parent 797df5a69c
commit 1f927d3391

@@ -1,47 +1,66 @@
 import os
 from typing import List
 
-from theflow import Param
+from theflow import Node, Param
 from theflow.utils.modules import ObjectInitDeclaration as _
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document, LLMInterface
 from kotaemon.contribs.promptui.logs import ResultLog
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.llms import AzureOpenAI
-from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
-from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
-from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
+from kotaemon.indices import VectorIndexing, VectorRetrieval
+from kotaemon.llms import AzureChatOpenAI
+from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
 
 
+class QAResultLog(ResultLog):
+    @staticmethod
+    def _get_prompt(obj):
+        return obj["prompt"]
+
+
 class QuestionAnsweringPipeline(BaseComponent):
-    retrieval_top_k: int = 1
-
-    llm: AzureOpenAI = AzureOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    _promptui_resultlog = QAResultLog
+    _promptui_outputs: list = [
+        {
+            "step": ".prompt",
+            "getter": "_get_prompt",
+            "component": "text",
+            "params": {"label": "Constructed prompt to LLM"},
+        },
+        {
+            "step": ".",
+            "getter": "_get_output",
+            "component": "text",
+            "params": {"label": "Answer"},
+        },
+    ]
+
+    retrieval_top_k: int = 1
+    llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
         temperature=0,
         request_timeout=60,
     )
 
-    retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
-        RetrieveDocumentFromVectorStorePipeline.withx(
+    retrieving_pipeline: VectorRetrieval = Node(
+        VectorRetrieval.withx(
             vector_store=_(ChromaVectorStore).withx(path="./tmp"),
+            doc_store=_(SimpleFileDocumentStore).withx(path="docstore.json"),
             embedding=AzureOpenAIEmbeddings.withx(
                 model="text-embedding-ada-002",
                 deployment="dummy-q2-text-embedding",
-                openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-                openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+                azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+                openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
             ),
-        )
-    )
+        ),
+        ignore_ui=True,
+    )
 
-    def run_raw(self, text: str) -> str:
-        # reload the document store, in case it has been updated
-        doc_store = InMemoryDocumentStore()
-        doc_store.load("docstore.json")
-        self.retrieving_pipeline.doc_store = doc_store
-
+    def run(self, text: str) -> LLMInterface:
        # retrieve relevant documents as context
         matched_texts: List[str] = [
             _.text
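The first hunk cuts off inside the new run method. Based on the context lines visible here and at the top of the next hunk (the matched_texts comprehension over retrieving_pipeline, the promptui ".prompt" getter, and the shared line return self.llm(prompt).text), the body plausibly continues along the lines sketched below. The prompt wording, the top_k keyword, and the log_progress call are assumptions, not verified against the committed file:

    def run(self, text: str) -> LLMInterface:
        # retrieve relevant documents as context
        matched_texts: List[str] = [
            _.text
            # assumed call shape: VectorRetrieval invoked with the query and top_k
            for _ in self.retrieving_pipeline(text, top_k=int(self.retrieval_top_k))
        ]
        context = "\n".join(matched_texts)

        # assumed prompt template; log it under the ".prompt" step so the
        # promptui output declared above can surface it via _get_prompt
        prompt = f"Use the context below to answer the question.\n{context}\nQuestion: {text}"
        self.log_progress(".prompt", prompt=prompt)  # assumed theflow logging hook

        return self.llm(prompt).text

Logging the constructed prompt under the ".prompt" step is what lets the _promptui_outputs entry labeled "Constructed prompt to LLM" display it in the Gradio V4 UI.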
@@ -56,35 +75,33 @@ class QuestionAnsweringPipeline(BaseComponent):
         return self.llm(prompt).text
 
 
-class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
-    # Expose variables for users to switch in prompt ui
-    embedding_model: str = "text-embedding-ada-002"
-    vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
-
-    @Param.auto()
-    def doc_store(self) -> InMemoryDocumentStore:
-        doc_store = InMemoryDocumentStore()
-        if os.path.isfile("docstore.json"):
-            doc_store.load("docstore.json")
-        return doc_store
-
+class IndexingPipeline(VectorIndexing):
+
+    vector_store: ChromaVectorStore = Param(
+        _(ChromaVectorStore).withx(path="./tmp"),
+        ignore_ui=True,
+    )
+    doc_store: SimpleFileDocumentStore = Param(
+        _(SimpleFileDocumentStore).withx(path="docstore.json"),
+        ignore_ui=True,
+    )
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
     )
 
-    def run_raw(self, text: str) -> int:  # type: ignore
+    def run(self, text: str) -> Document:
         """Normally, this indexing pipeline returns nothing. For demonstration,
         we want it to return something, so let's return the number of documents
         in the vector store
         """
-        super().run_raw(text)
+        super().run(text)
 
         if self.doc_store is not None:
             # persist to local anytime an indexing is created
             # this can be bypassed when we have a FileDocumentStore
             self.doc_store.save("docstore.json")
 
-        return self.vector_store._collection.count()
+        return Document(self.vector_store._collection.count())
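Taken together, the two hunks follow one migration pattern: run_raw becomes run, AzureOpenAI becomes AzureChatOpenAI keyed by azure_endpoint, the ad-hoc kotaemon.pipelines classes become VectorIndexing/VectorRetrieval, and members that should not be tunable in the UI are wrapped in Node(...)/Param(..., ignore_ui=True). A hypothetical smoke test of the migrated file is sketched below; the pipeline module name and the sample strings are assumptions, not part of this commit:

    # hypothetical import path; the template file's real location is not shown in the diff
    from pipeline import IndexingPipeline, QuestionAnsweringPipeline

    # requires OPENAI_API_KEY to be set; both pipelines read it via os.environ.get

    # indexing: per the new run(), this persists docstore.json as a side effect
    # and returns the vector-store collection count wrapped in a Document
    indexing = IndexingPipeline()
    print(indexing.run("kotaemon is a toolkit for building document QA pipelines."))

    # question answering against the freshly indexed store
    qa = QuestionAnsweringPipeline()
    print(qa.run("What is kotaemon?"))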