Upgrade promptui to conform to Gradio V4 (#98)

Duc Nguyen (john) 2023-12-07 15:24:07 +07:00 committed by GitHub
parent 797df5a69c
commit 1f927d3391
7 changed files with 68 additions and 56 deletions

View File

@@ -84,9 +84,7 @@ def handle_node(node: dict) -> dict:
 
 def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
     """Get the input from the pipeline"""
-    if not hasattr(pipeline, "run_raw"):
-        return {}
-    signature = inspect.signature(pipeline.run_raw)
+    signature = inspect.signature(pipeline.run)
     inputs: Dict[str, Dict] = {}
     for name, param in signature.parameters.items():
         if name in ["self", "args", "kwargs"]:
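
Note: with run_raw gone, promptui now derives the UI inputs from the signature of run. A minimal sketch of the idea, using a hypothetical pipeline class:

    import inspect

    class Pipeline:
        def run(self, text: str, top_k: int = 3):
            ...

    # collect {param_name: default} from run(), skipping self/*args/**kwargs
    signature = inspect.signature(Pipeline.run)
    inputs = {
        name: param.default
        for name, param in signature.parameters.items()
        if name not in ["self", "args", "kwargs"]
    }
    print(inputs)  # {'text': <class 'inspect._empty'>, 'top_k': 3}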

View File

@@ -4,7 +4,7 @@ from typing import Any, AsyncGenerator
 
 import anyio
 from gradio import ChatInterface
-from gradio.components import IOComponent, get_component_instance
+from gradio.components import Component, get_component_instance
 from gradio.events import on
 from gradio.helpers import special_args
 from gradio.routes import Request
@@ -20,7 +20,7 @@ class ChatBlock(ChatInterface):
     def __init__(
         self,
         *args,
-        additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
+        additional_outputs: str | Component | list[str | Component] | None = None,
        **kwargs,
    ):
        if additional_outputs:
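
Note: Gradio 4 renamed the IOComponent base class to Component, which is all this hunk tracks. A code base that still had to run against Gradio 3.x could hedge with a fallback import (a sketch, not part of this commit):

    try:
        from gradio.components import Component  # Gradio >= 4
    except ImportError:  # Gradio 3.x used the old name
        from gradio.components import IOComponent as Component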

View File

@@ -44,7 +44,7 @@ class VectorIndexing(BaseIndexing):
             qa_pipeline=CitationQAPipeline(**kwargs),
         )
 
-    def run(self, text: str | list[str] | Document | list[Document]) -> None:
+    def run(self, text: str | list[str] | Document | list[Document]):
         input_: list[Document] = []
         if not isinstance(text, list):
             text = [text]
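
Note: run accepts a single string, a list of strings, a Document, or a list of Documents; the visible body normalizes the input to a list before filling input_. A sketch of the coercion, where the string-to-Document wrapping is an assumption (only the list normalization is shown in the hunk):

    from kotaemon.base import Document

    def normalize(text) -> list[Document]:
        if not isinstance(text, list):
            text = [text]  # single item -> one-element list (as in the diff)
        # assumed: bare strings are wrapped into Document objects downstream
        return [t if isinstance(t, Document) else Document(text=t) for t in text]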

View File

@@ -111,7 +111,7 @@ class OpenAI(LCCompletionMixin, LLM):
         openai_api_base: Optional[str] = None,
         model_name: str = "text-davinci-003",
         temperature: float = 0.7,
-        max_token: int = 256,
+        max_tokens: int = 256,
         top_p: float = 1,
         frequency_penalty: float = 0,
         n: int = 1,
@@ -126,7 +126,7 @@ class OpenAI(LCCompletionMixin, LLM):
             openai_api_base=openai_api_base,
             model_name=model_name,
             temperature=temperature,
-            max_token=max_token,
+            max_tokens=max_tokens,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             n=n,
@@ -154,7 +154,7 @@ class AzureOpenAI(LCCompletionMixin, LLM):
         openai_api_key: Optional[str] = None,
         model_name: str = "text-davinci-003",
         temperature: float = 0.7,
-        max_token: int = 256,
+        max_tokens: int = 256,
         top_p: float = 1,
         frequency_penalty: float = 0,
         n: int = 1,
@@ -171,7 +171,7 @@ class AzureOpenAI(LCCompletionMixin, LLM):
             openai_api_key=openai_api_key,
             model_name=model_name,
             temperature=temperature,
-            max_token=max_token,
+            max_tokens=max_tokens,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             n=n,

View File

@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
 # metadata and dependencies
 [project]
 name = "kotaemon"
-version = "0.3.4"
+version = "0.3.5"
 requires-python = ">= 3.10"
 description = "Kotaemon core library for AI development."
 dependencies = [
@@ -19,7 +19,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0",
     "llama-hub",
-    "gradio",
+    "gradio>=4.0.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -56,7 +56,7 @@ dev = [
     "python-dotenv",
     "pytest-mock",
     "unstructured[pdf]",
-    "farm-haystack==1.19.0",
+    # "farm-haystack==1.22.1",
     "sentence_transformers",
     "cohere",
     "elasticsearch",

View File

@@ -1,47 +1,66 @@
 import os
 from typing import List
 
-from theflow import Param
+from theflow import Node, Param
 from theflow.utils.modules import ObjectInitDeclaration as _
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document, LLMInterface
+from kotaemon.contribs.promptui.logs import ResultLog
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.llms import AzureOpenAI
-from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
-from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
-from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
+from kotaemon.indices import VectorIndexing, VectorRetrieval
+from kotaemon.llms import AzureChatOpenAI
+from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
 
 
+class QAResultLog(ResultLog):
+    @staticmethod
+    def _get_prompt(obj):
+        return obj["prompt"]
+
+
 class QuestionAnsweringPipeline(BaseComponent):
-    retrieval_top_k: int = 1
+    _promptui_resultlog = QAResultLog
+    _promptui_outputs: list = [
+        {
+            "step": ".prompt",
+            "getter": "_get_prompt",
+            "component": "text",
+            "params": {"label": "Constructed prompt to LLM"},
+        },
+        {
+            "step": ".",
+            "getter": "_get_output",
+            "component": "text",
+            "params": {"label": "Answer"},
+        },
+    ]
 
-    llm: AzureOpenAI = AzureOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    retrieval_top_k: int = 1
+
+    llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
         temperature=0,
         request_timeout=60,
     )
 
-    retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
-        RetrieveDocumentFromVectorStorePipeline.withx(
+    retrieving_pipeline: VectorRetrieval = Node(
+        VectorRetrieval.withx(
             vector_store=_(ChromaVectorStore).withx(path="./tmp"),
+            doc_store=_(SimpleFileDocumentStore).withx(path="docstore.json"),
             embedding=AzureOpenAIEmbeddings.withx(
                 model="text-embedding-ada-002",
                 deployment="dummy-q2-text-embedding",
-                openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-                openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+                azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+                openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
             ),
-        )
+        ),
+        ignore_ui=True,
     )
 
-    def run_raw(self, text: str) -> str:
-        # reload the document store, in case it has been updated
-        doc_store = InMemoryDocumentStore()
-        doc_store.load("docstore.json")
-        self.retrieving_pipeline.doc_store = doc_store
-
+    def run(self, text: str) -> LLMInterface:
         # retrieve relevant documents as context
         matched_texts: List[str] = [
             _.text
@@ -56,35 +75,33 @@ class QuestionAnsweringPipeline(BaseComponent):
         return self.llm(prompt).text
 
 
-class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
-    # Expose variables for users to switch in prompt ui
-    embedding_model: str = "text-embedding-ada-002"
-    vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
-
-    @Param.auto()
-    def doc_store(self) -> InMemoryDocumentStore:
-        doc_store = InMemoryDocumentStore()
-        if os.path.isfile("docstore.json"):
-            doc_store.load("docstore.json")
-        return doc_store
+class IndexingPipeline(VectorIndexing):
+    vector_store: ChromaVectorStore = Param(
+        _(ChromaVectorStore).withx(path="./tmp"),
+        ignore_ui=True,
+    )
+    doc_store: SimpleFileDocumentStore = Param(
+        _(SimpleFileDocumentStore).withx(path="docstore.json"),
+        ignore_ui=True,
+    )
 
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
     )
 
-    def run_raw(self, text: str) -> int:  # type: ignore
+    def run(self, text: str) -> Document:
         """Normally, this indexing pipeline returns nothing. For demonstration,
         we want it to return something, so let's return the number of documents
         in the vector store
         """
-        super().run_raw(text)
+        super().run(text)
 
         if self.doc_store is not None:
             # persist to local anytime an indexing is created
             # this can be bypassed when we have a FileDocumentStore
             self.doc_store.save("docstore.json")
 
-        return self.vector_store._collection.count()
+        return Document(self.vector_store._collection.count())
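
Note: the migrated example exercises the two promptui extension points this PR leans on: a ResultLog subclass whose static getters extract values from logged steps, declared via _promptui_outputs, and Node/Param wrappers marked ignore_ui=True so storage wiring never surfaces as a UI control. The pattern, condensed from the diff above (class names here are stand-ins):

    from theflow import Node, Param
    from kotaemon.base import BaseComponent
    from kotaemon.contribs.promptui.logs import ResultLog

    class MyResultLog(ResultLog):
        @staticmethod
        def _get_prompt(obj):
            # pull the constructed prompt out of the logged step output
            return obj["prompt"]

    class MyPipeline(BaseComponent):
        _promptui_resultlog = MyResultLog
        _promptui_outputs = [
            # step: which logged step to read; getter: static method on the
            # ResultLog subclass; component/params: how promptui renders it
            {
                "step": ".prompt",
                "getter": "_get_prompt",
                "component": "text",
                "params": {"label": "Constructed prompt to LLM"},
            },
        ]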

View File

@@ -3,7 +3,7 @@ from typing import List
 
 from theflow.utils.modules import ObjectInitDeclaration as _
 
-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorRetrieval
 from kotaemon.llms import AzureOpenAI
@@ -30,9 +30,6 @@ class Pipeline(BaseComponent):
         ),
     )
 
-    def run_raw(self, text: str) -> str:
+    def run(self, text: str) -> LLMInterface:
         matched_texts: List[str] = self.retrieving_pipeline(text)
-        return self.llm("\n".join(matched_texts)).text
-
-    def run(self):
-        ...
+        return self.llm("\n".join(matched_texts))
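
Note: taken together, the recurring change across these files is the retirement of the run_raw/run split: each component now exposes a single run method returning a structured object (LLMInterface or Document) rather than a bare string, and callers read .text off the result when they need plain text. Schematically:

    # before: promptui drove run_raw and expected a plain string
    class OldPipeline(BaseComponent):
        def run_raw(self, text: str) -> str:
            return self.llm(text).text

        def run(self):
            ...

    # after: one entry point, structured return value
    class NewPipeline(BaseComponent):
        def run(self, text: str) -> LLMInterface:
            return self.llm(text)  # callers take .text when they need a string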