Upgrade promptui to conform to Gradio V4 (#98)

Authored by Duc Nguyen (john) on 2023-12-07 15:24:07 +07:00; committed by GitHub
parent 797df5a69c
commit 1f927d3391
7 changed files with 68 additions and 56 deletions

View File

@@ -84,9 +84,7 @@ def handle_node(node: dict) -> dict:
 
 def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
     """Get the input from the pipeline"""
-    if not hasattr(pipeline, "run_raw"):
-        return {}
-    signature = inspect.signature(pipeline.run_raw)
+    signature = inspect.signature(pipeline.run)
     inputs: Dict[str, Dict] = {}
     for name, param in signature.parameters.items():
         if name in ["self", "args", "kwargs"]:
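For reference, a minimal sketch of the introspection this hunk switches over to: promptui now derives UI inputs from the signature of run rather than run_raw. The helper name and the component/params shape are illustrative only (the shape mirrors the promptui output declarations later in this commit), not the library's exact code:

    import inspect
    from typing import Dict

    def preview_inputs(pipeline) -> Dict[str, Dict]:
        # Walk pipeline.run the same way handle_input does, skipping
        # self/args/kwargs; every remaining parameter becomes a UI field.
        inputs: Dict[str, Dict] = {}
        for name, param in inspect.signature(pipeline.run).parameters.items():
            if name in ["self", "args", "kwargs"]:
                continue
            default = None if param.default is inspect.Parameter.empty else param.default
            inputs[name] = {"component": "text", "params": {"value": default}}
        return inputs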

View File

@@ -4,7 +4,7 @@ from typing import Any, AsyncGenerator
 import anyio
 from gradio import ChatInterface
-from gradio.components import IOComponent, get_component_instance
+from gradio.components import Component, get_component_instance
 from gradio.events import on
 from gradio.helpers import special_args
 from gradio.routes import Request
@@ -20,7 +20,7 @@ class ChatBlock(ChatInterface):
     def __init__(
         self,
         *args,
-        additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
+        additional_outputs: str | Component | list[str | Component] | None = None,
         **kwargs,
     ):
         if additional_outputs:
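Gradio 4 renamed the base class IOComponent to Component in gradio.components. This commit targets Gradio 4 outright via the pyproject pin below; a version-tolerant import, if one were wanted instead, would be a sketch along these lines:

    try:
        # Gradio >= 4: the base class lives at gradio.components.Component
        from gradio.components import Component
    except ImportError:
        # Gradio 3.x used the name IOComponent
        from gradio.components import IOComponent as Component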

View File

@@ -44,7 +44,7 @@ class VectorIndexing(BaseIndexing):
             qa_pipeline=CitationQAPipeline(**kwargs),
         )
 
-    def run(self, text: str | list[str] | Document | list[Document]) -> None:
+    def run(self, text: str | list[str] | Document | list[Document]):
         input_: list[Document] = []
         if not isinstance(text, list):
             text = [text]
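The body of run (unchanged here) still normalizes its input before indexing: a single str or Document is wrapped into a list. A standalone sketch of that pattern; Document(text=...) is an assumed constructor shape, not taken from this diff:

    from kotaemon.base import Document

    def as_documents(text):
        # Mirror the normalization in VectorIndexing.run: wrap single inputs
        # into a list, then coerce plain strings into Document objects.
        if not isinstance(text, list):
            text = [text]
        return [t if isinstance(t, Document) else Document(text=t) for t in text]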

View File

@@ -111,7 +111,7 @@ class OpenAI(LCCompletionMixin, LLM):
         openai_api_base: Optional[str] = None,
         model_name: str = "text-davinci-003",
         temperature: float = 0.7,
-        max_token: int = 256,
+        max_tokens: int = 256,
         top_p: float = 1,
         frequency_penalty: float = 0,
         n: int = 1,
@@ -126,7 +126,7 @@ class OpenAI(LCCompletionMixin, LLM):
             openai_api_base=openai_api_base,
             model_name=model_name,
             temperature=temperature,
-            max_token=max_token,
+            max_tokens=max_tokens,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             n=n,
@@ -154,7 +154,7 @@ class AzureOpenAI(LCCompletionMixin, LLM):
         openai_api_key: Optional[str] = None,
         model_name: str = "text-davinci-003",
         temperature: float = 0.7,
-        max_token: int = 256,
+        max_tokens: int = 256,
         top_p: float = 1,
         frequency_penalty: float = 0,
         n: int = 1,
@@ -171,7 +171,7 @@ class AzureOpenAI(LCCompletionMixin, LLM):
             openai_api_key=openai_api_key,
             model_name=model_name,
             temperature=temperature,
-            max_token=max_token,
+            max_tokens=max_tokens,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             n=n,
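The rename fixes a typo: the OpenAI completion API's parameter is max_tokens, so the old max_token keyword never matched it. A usage sketch against the updated signature, with placeholder credentials (not values from this repo):

    from kotaemon.llms import AzureOpenAI

    llm = AzureOpenAI(
        openai_api_key="<azure-openai-key>",  # placeholder
        model_name="text-davinci-003",
        temperature=0.7,
        max_tokens=256,  # now spelled to match the OpenAI API
    )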

View File

@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
 # metadata and dependencies
 [project]
 name = "kotaemon"
-version = "0.3.4"
+version = "0.3.5"
 requires-python = ">= 3.10"
 description = "Kotaemon core library for AI development."
 dependencies = [
@@ -19,7 +19,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0",
     "llama-hub",
-    "gradio",
+    "gradio>=4.0.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -56,7 +56,7 @@ dev = [
     "python-dotenv",
     "pytest-mock",
     "unstructured[pdf]",
-    "farm-haystack==1.19.0",
+    # "farm-haystack==1.22.1",
     "sentence_transformers",
     "cohere",
     "elasticsearch",

View File

@@ -1,47 +1,66 @@
 import os
 from typing import List
 
-from theflow import Param
+from theflow import Node, Param
 from theflow.utils.modules import ObjectInitDeclaration as _
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document, LLMInterface
+from kotaemon.contribs.promptui.logs import ResultLog
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.llms import AzureOpenAI
-from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
-from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
-from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
+from kotaemon.indices import VectorIndexing, VectorRetrieval
+from kotaemon.llms import AzureChatOpenAI
+from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
 
 
+class QAResultLog(ResultLog):
+    @staticmethod
+    def _get_prompt(obj):
+        return obj["prompt"]
+
+
 class QuestionAnsweringPipeline(BaseComponent):
-    retrieval_top_k: int = 1
-
-    llm: AzureOpenAI = AzureOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    _promptui_resultlog = QAResultLog
+    _promptui_outputs: list = [
+        {
+            "step": ".prompt",
+            "getter": "_get_prompt",
+            "component": "text",
+            "params": {"label": "Constructed prompt to LLM"},
+        },
+        {
+            "step": ".",
+            "getter": "_get_output",
+            "component": "text",
+            "params": {"label": "Answer"},
+        },
+    ]
+
+    retrieval_top_k: int = 1
+
+    llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
         temperature=0,
         request_timeout=60,
     )
 
-    retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
-        RetrieveDocumentFromVectorStorePipeline.withx(
+    retrieving_pipeline: VectorRetrieval = Node(
+        VectorRetrieval.withx(
             vector_store=_(ChromaVectorStore).withx(path="./tmp"),
+            doc_store=_(SimpleFileDocumentStore).withx(path="docstore.json"),
             embedding=AzureOpenAIEmbeddings.withx(
                 model="text-embedding-ada-002",
                 deployment="dummy-q2-text-embedding",
-                openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-                openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+                azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+                openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
             ),
-        )
-    )
+        ),
+        ignore_ui=True,
+    )
 
-    def run_raw(self, text: str) -> str:
-        # reload the document store, in case it has been updated
-        doc_store = InMemoryDocumentStore()
-        doc_store.load("docstore.json")
-        self.retrieving_pipeline.doc_store = doc_store
-
+    def run(self, text: str) -> LLMInterface:
         # retrieve relevant documents as context
         matched_texts: List[str] = [
             _.text
@@ -56,35 +75,33 @@ class QuestionAnsweringPipeline(BaseComponent):
         return self.llm(prompt).text
 
 
-class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
-    # Expose variables for users to switch in prompt ui
-    embedding_model: str = "text-embedding-ada-002"
-    vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
-
-    @Param.auto()
-    def doc_store(self) -> InMemoryDocumentStore:
-        doc_store = InMemoryDocumentStore()
-        if os.path.isfile("docstore.json"):
-            doc_store.load("docstore.json")
-        return doc_store
-
+class IndexingPipeline(VectorIndexing):
+    vector_store: ChromaVectorStore = Param(
+        _(ChromaVectorStore).withx(path="./tmp"),
+        ignore_ui=True,
+    )
+    doc_store: SimpleFileDocumentStore = Param(
+        _(SimpleFileDocumentStore).withx(path="docstore.json"),
+        ignore_ui=True,
+    )
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
     )
 
-    def run_raw(self, text: str) -> int:  # type: ignore
+    def run(self, text: str) -> Document:
         """Normally, this indexing pipeline returns nothing. For demonstration,
         we want it to return something, so let's return the number of documents
         in the vector store
         """
-        super().run_raw(text)
+        super().run(text)
         if self.doc_store is not None:
             # persist to local anytime an indexing is created
             # this can be bypassed when we have a FileDocumentStore
             self.doc_store.save("docstore.json")
-        return self.vector_store._collection.count()
+        return Document(self.vector_store._collection.count())
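A sketch of exercising the two example pipelines above end to end. It assumes OPENAI_API_KEY is set and the dummy Azure endpoint in the file is reachable; the returned Document follows the wrapping added in this hunk:

    # index a snippet, then ask a question against the persisted stores
    index = IndexingPipeline()
    count_doc = index.run("Kotaemon pipelines can be exported to promptui.")
    print(count_doc)  # Document wrapping the vector-store count

    qa = QuestionAnsweringPipeline()
    answer = qa.run("What can kotaemon pipelines be exported to?")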

View File

@@ -3,7 +3,7 @@ from typing import List
 
 from theflow.utils.modules import ObjectInitDeclaration as _
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorRetrieval
 from kotaemon.llms import AzureOpenAI
@@ -30,9 +30,6 @@ class Pipeline(BaseComponent):
         ),
     )
 
-    def run_raw(self, text: str) -> str:
+    def run(self, text: str) -> LLMInterface:
         matched_texts: List[str] = self.retrieving_pipeline(text)
-        return self.llm("\n".join(matched_texts)).text
-
-    def run(self):
-        ...
+        return self.llm("\n".join(matched_texts))
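Since promptui now introspects run directly, the placeholder run stub is gone and the pipeline is invoked like any other component. A usage sketch; result.text follows the LLMInterface usage visible in the old return statement:

    pipeline = Pipeline()
    result = pipeline.run("What is kotaemon?")  # LLMInterface, not a bare str
    print(result.text)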