Upgrade promptui to conform to Gradio V4 (#98)
parent 797df5a69c
commit 1f927d3391
@@ -84,9 +84,7 @@ def handle_node(node: dict) -> dict:

 def handle_input(pipeline: Union[BaseComponent, Type[BaseComponent]]) -> dict:
     """Get the input from the pipeline"""
-    if not hasattr(pipeline, "run_raw"):
-        return {}
-    signature = inspect.signature(pipeline.run_raw)
+    signature = inspect.signature(pipeline.run)
     inputs: Dict[str, Dict] = {}
     for name, param in signature.parameters.items():
         if name in ["self", "args", "kwargs"]:
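Note: with the run_raw split gone, promptui derives UI inputs directly from the signature of run. A minimal sketch of the introspection this hunk relies on (example_run and the default-extraction shape are illustrative, not the library's exact code):

import inspect
from typing import Dict

def example_run(self, text: str, top_k: int = 3, *args, **kwargs):
    ...

inputs: Dict[str, Dict] = {}
signature = inspect.signature(example_run)
for name, param in signature.parameters.items():
    if name in ["self", "args", "kwargs"]:
        continue  # skip non-UI parameters, as handle_input does
    default = None if param.default is param.empty else param.default
    inputs[name] = {"default": default}

print(inputs)  # {'text': {'default': None}, 'top_k': {'default': 3}}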
@@ -4,7 +4,7 @@ from typing import Any, AsyncGenerator

 import anyio
 from gradio import ChatInterface
-from gradio.components import IOComponent, get_component_instance
+from gradio.components import Component, get_component_instance
 from gradio.events import on
 from gradio.helpers import special_args
 from gradio.routes import Request
@@ -20,7 +20,7 @@ class ChatBlock(ChatInterface):
     def __init__(
         self,
         *args,
-        additional_outputs: str | IOComponent | list[str | IOComponent] | None = None,
+        additional_outputs: str | Component | list[str | Component] | None = None,
         **kwargs,
     ):
         if additional_outputs:
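Note: Gradio 4 folded the old IOComponent base class into Component, so the import and the additional_outputs annotation migrate one-for-one. A hedged sketch of how such a str-or-component argument is typically normalized (the helper name normalize_outputs is ours, not Gradio's):

from gradio.components import Component, get_component_instance

def normalize_outputs(outputs):
    # Accept a component name like "text", a Component instance,
    # or a list mixing both; return a list of Component instances.
    if not isinstance(outputs, list):
        outputs = [outputs]
    return [
        o if isinstance(o, Component) else get_component_instance(o)
        for o in outputs
    ]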
@@ -44,7 +44,7 @@ class VectorIndexing(BaseIndexing):
             qa_pipeline=CitationQAPipeline(**kwargs),
         )

-    def run(self, text: str | list[str] | Document | list[Document]) -> None:
+    def run(self, text: str | list[str] | Document | list[Document]):
         input_: list[Document] = []
         if not isinstance(text, list):
             text = [text]
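Note: the loosened return annotation aside, run still coerces every accepted input shape to list[Document] before indexing. A rough sketch of that normalization (assuming Document accepts a text= keyword, as elsewhere in kotaemon):

from kotaemon.base import Document

def to_documents(text):
    # str | list[str] | Document | list[Document]  ->  list[Document]
    if not isinstance(text, list):
        text = [text]
    return [t if isinstance(t, Document) else Document(text=t) for t in text]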
@@ -111,7 +111,7 @@ class OpenAI(LCCompletionMixin, LLM):
         openai_api_base: Optional[str] = None,
         model_name: str = "text-davinci-003",
         temperature: float = 0.7,
-        max_token: int = 256,
+        max_tokens: int = 256,
         top_p: float = 1,
         frequency_penalty: float = 0,
         n: int = 1,
@@ -126,7 +126,7 @@ class OpenAI(LCCompletionMixin, LLM):
             openai_api_base=openai_api_base,
             model_name=model_name,
             temperature=temperature,
-            max_token=max_token,
+            max_tokens=max_tokens,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             n=n,
@@ -154,7 +154,7 @@ class AzureOpenAI(LCCompletionMixin, LLM):
         openai_api_key: Optional[str] = None,
         model_name: str = "text-davinci-003",
         temperature: float = 0.7,
-        max_token: int = 256,
+        max_tokens: int = 256,
         top_p: float = 1,
         frequency_penalty: float = 0,
         n: int = 1,
@@ -171,7 +171,7 @@ class AzureOpenAI(LCCompletionMixin, LLM):
             openai_api_key=openai_api_key,
             model_name=model_name,
             temperature=temperature,
-            max_token=max_token,
+            max_tokens=max_tokens,
             top_p=top_p,
             frequency_penalty=frequency_penalty,
             n=n,
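Note: the max_token -> max_tokens rename in both classes matters because the underlying OpenAI completion parameter is spelled max_tokens; under the old misspelling the value could never be forwarded under the name the API expects. Illustrative construction with the corrected keyword (the key is a placeholder):

from kotaemon.llms import AzureOpenAI

llm = AzureOpenAI(
    openai_api_key="<your-key>",   # placeholder
    model_name="text-davinci-003",
    temperature=0.7,
    max_tokens=256,                # was max_token, which never matched the API's parameter
    top_p=1,
    frequency_penalty=0,
    n=1,
)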
@@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
 # metadata and dependencies
 [project]
 name = "kotaemon"
-version = "0.3.4"
+version = "0.3.5"
 requires-python = ">= 3.10"
 description = "Kotaemon core library for AI development."
 dependencies = [
@@ -19,7 +19,7 @@ dependencies = [
     "theflow",
     "llama-index>=0.9.0",
     "llama-hub",
-    "gradio",
+    "gradio>=4.0.0",
     "openpyxl",
     "cookiecutter",
     "click",
@@ -56,7 +56,7 @@ dev = [
     "python-dotenv",
     "pytest-mock",
     "unstructured[pdf]",
-    "farm-haystack==1.19.0",
+    # "farm-haystack==1.22.1",
     "sentence_transformers",
     "cohere",
     "elasticsearch",
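Note: since promptui now targets the Gradio 4 API surface, the gradio>=4.0.0 floor is load-bearing. A quick runtime guard one could add in environments where an older Gradio might still be installed (illustrative, not part of this commit):

import gradio

major = int(gradio.__version__.split(".")[0])
assert major >= 4, f"kotaemon promptui requires gradio>=4.0.0, found {gradio.__version__}"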
@@ -1,47 +1,66 @@
 import os
 from typing import List

-from theflow import Param
+from theflow import Node, Param
 from theflow.utils.modules import ObjectInitDeclaration as _

-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, Document, LLMInterface
+from kotaemon.contribs.promptui.logs import ResultLog
 from kotaemon.embeddings import AzureOpenAIEmbeddings
-from kotaemon.llms import AzureOpenAI
-from kotaemon.pipelines.indexing import IndexVectorStoreFromDocumentPipeline
-from kotaemon.pipelines.retrieving import RetrieveDocumentFromVectorStorePipeline
-from kotaemon.storages import ChromaVectorStore, InMemoryDocumentStore
+from kotaemon.indices import VectorIndexing, VectorRetrieval
+from kotaemon.llms import AzureChatOpenAI
+from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
+
+
+class QAResultLog(ResultLog):
+    @staticmethod
+    def _get_prompt(obj):
+        return obj["prompt"]


 class QuestionAnsweringPipeline(BaseComponent):
-    retrieval_top_k: int = 1
-
-    llm: AzureOpenAI = AzureOpenAI.withx(
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+    _promptui_resultlog = QAResultLog
+    _promptui_outputs: list = [
+        {
+            "step": ".prompt",
+            "getter": "_get_prompt",
+            "component": "text",
+            "params": {"label": "Constructed prompt to LLM"},
+        },
+        {
+            "step": ".",
+            "getter": "_get_output",
+            "component": "text",
+            "params": {"label": "Answer"},
+        },
+    ]
+
+    retrieval_top_k: int = 1
+
+    llm: AzureChatOpenAI = AzureChatOpenAI.withx(
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
         openai_api_version="2023-03-15-preview",
         deployment_name="dummy-q2-gpt35",
         temperature=0,
         request_timeout=60,
     )

-    retrieving_pipeline: RetrieveDocumentFromVectorStorePipeline = (
-        RetrieveDocumentFromVectorStorePipeline.withx(
+    retrieving_pipeline: VectorRetrieval = Node(
+        VectorRetrieval.withx(
             vector_store=_(ChromaVectorStore).withx(path="./tmp"),
+            doc_store=_(SimpleFileDocumentStore).withx(path="docstore.json"),
             embedding=AzureOpenAIEmbeddings.withx(
                 model="text-embedding-ada-002",
                 deployment="dummy-q2-text-embedding",
-                openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-                openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+                azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+                openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
             ),
-        )
+        ),
+        ignore_ui=True,
     )

-    def run_raw(self, text: str) -> str:
-        # reload the document store, in case it has been updated
-        doc_store = InMemoryDocumentStore()
-        doc_store.load("docstore.json")
-        self.retrieving_pipeline.doc_store = doc_store
-
+    def run(self, text: str) -> LLMInterface:
         # retrieve relevant documents as context
         matched_texts: List[str] = [
             _.text
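Note: the _promptui_outputs entries declare how the UI surfaces intermediate results: each names a logged step, a getter on the ResultLog subclass, and a Gradio component to render with. A sketch of how such a spec could be consumed (the lookup mechanics below are our assumption, not quoted from promptui):

spec = {
    "step": ".prompt",
    "getter": "_get_prompt",
    "component": "text",
    "params": {"label": "Constructed prompt to LLM"},
}

class QAResultLog:
    @staticmethod
    def _get_prompt(obj):
        return obj["prompt"]

# hypothetical logged pipeline state, keyed by step path
logged = {".prompt": {"prompt": "Answer the question using the context below: ..."}}
value = getattr(QAResultLog, spec["getter"])(logged[spec["step"]])
print(value)  # rendered into a "text" component labeled per spec["params"]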
@@ -56,35 +75,33 @@ class QuestionAnsweringPipeline(BaseComponent):
         return self.llm(prompt).text


-class IndexingPipeline(IndexVectorStoreFromDocumentPipeline):
-    # Expose variables for users to switch in prompt ui
-    embedding_model: str = "text-embedding-ada-002"
-    vector_store: _[ChromaVectorStore] = _(ChromaVectorStore).withx(path="./tmp")
-
-    @Param.auto()
-    def doc_store(self) -> InMemoryDocumentStore:
-        doc_store = InMemoryDocumentStore()
-        if os.path.isfile("docstore.json"):
-            doc_store.load("docstore.json")
-        return doc_store
-
+class IndexingPipeline(VectorIndexing):
+    vector_store: ChromaVectorStore = Param(
+        _(ChromaVectorStore).withx(path="./tmp"),
+        ignore_ui=True,
+    )
+    doc_store: SimpleFileDocumentStore = Param(
+        _(SimpleFileDocumentStore).withx(path="docstore.json"),
+        ignore_ui=True,
+    )
     embedding: AzureOpenAIEmbeddings = AzureOpenAIEmbeddings.withx(
         model="text-embedding-ada-002",
         deployment="dummy-q2-text-embedding",
-        openai_api_base="https://bleh-dummy-2.openai.azure.com/",
-        openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
+        azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
+        openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
     )

-    def run_raw(self, text: str) -> int:  # type: ignore
+    def run(self, text: str) -> Document:
         """Normally, this indexing pipeline returns nothing. For demonstration,
         we want it to return something, so let's return the number of documents
         in the vector store
         """
-        super().run_raw(text)
+        super().run(text)

         if self.doc_store is not None:
             # persist to local anytime an indexing is created
             # this can be bypassed when we have a FileDocumentStore
             self.doc_store.save("docstore.json")

-        return self.vector_store._collection.count()
+        return Document(self.vector_store._collection.count())
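Note: much of the removed code was bookkeeping: run_raw reloaded an InMemoryDocumentStore from disk on every call, and IndexingPipeline lazily built one via @Param.auto(). Handing the path to SimpleFileDocumentStore lets the store own its persistence. Rough before/after, assuming the constructor keyword mirrors the .withx(path=...) usage above:

# before: manual load on every call
# doc_store = InMemoryDocumentStore()
# doc_store.load("docstore.json")

# after: the store is bound to its file once
from kotaemon.storages import SimpleFileDocumentStore
doc_store = SimpleFileDocumentStore(path="docstore.json")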
@@ -3,7 +3,7 @@ from typing import List

 from theflow.utils.modules import ObjectInitDeclaration as _

-from kotaemon.base import BaseComponent
+from kotaemon.base import BaseComponent, LLMInterface
 from kotaemon.embeddings import AzureOpenAIEmbeddings
 from kotaemon.indices import VectorRetrieval
 from kotaemon.llms import AzureOpenAI
@@ -30,9 +30,6 @@ class Pipeline(BaseComponent):
         ),
     )

-    def run_raw(self, text: str) -> str:
+    def run(self, text: str) -> LLMInterface:
         matched_texts: List[str] = self.retrieving_pipeline(text)
-        return self.llm("\n".join(matched_texts)).text
-
-    def run(self):
-        ...
+        return self.llm("\n".join(matched_texts))
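Note: across all three example pipelines the migration pattern is the same: the run_raw/run split collapses into a single run that returns a typed object (LLMInterface or Document) instead of a bare str or int, which is what the Gradio-4 promptui introspects. Schematic before/after (the llm attribute is a stand-in for the pipelines above):

# before
class OldPipeline:
    def run_raw(self, text: str) -> str:
        return self.llm(text).text

    def run(self):
        ...

# after
class NewPipeline:
    def run(self, text: str) -> "LLMInterface":
        return self.llm(text)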