Allow users to add LLM within the UI (#6)

* Rename AzureChatOpenAI to LCAzureChatOpenAI
* Provide vanilla ChatOpenAI and AzureChatOpenAI
* Remove the highest accuracy, lowest cost criteria

These criteria are unnecessary: the users, not the pipeline creators, should choose
which LLM to use. Moreover, requiring this information is cumbersome to input and
really degrades the user experience.

* Remove the LLM selection in simple reasoning pipeline
* Provide a dedicated stream method to generate the output
* Return placeholder message to chat if the text is empty
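
A minimal sketch of the streaming contract the last two bullets introduce (illustrative only; `render_chat` is a hypothetical helper, the real consumer is `ChatPage.chat_fn` further down in this diff):

```python
from kotaemon.base import Document

def render_chat(pipeline, chat_input, conversation_id, chat_history):
    """Hypothetical consumer of the new reasoning-pipeline ``stream`` method."""
    text = ""
    msg_placeholder = "Thinking ..."  # ktem falls back to KH_CHAT_MSG_PLACEHOLDER
    for response in pipeline.stream(chat_input, conversation_id, chat_history):
        if not isinstance(response, Document) or response.channel != "chat":
            continue
        text += response.content or ""
        # show the placeholder while the generated text is still empty
        yield chat_history + [(chat_input, text or msg_placeholder)]
```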
Duc Nguyen (john) 2024-04-06 11:53:17 +07:00 committed by GitHub
parent e187e23dd1
commit a203fc0f7c
35 changed files with 1339 additions and 169 deletions

1
.gitignore vendored
View File

@ -458,6 +458,7 @@ logs/
.gitsecret/keys/random_seed
!*.secret
.envrc
.env
S.gpg-agent*
.vscode/settings.json

View File

@ -22,7 +22,7 @@ The syntax of a component is as follow:
```python
from kotaemon.base import BaseComponent
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.parsers import RegexExtractor
@ -32,7 +32,7 @@ class FancyPipeline(BaseComponent):
param3: float
node1: BaseComponent # this is a node because of BaseComponent type annotation
node2: AzureChatOpenAI # this is also a node because AzureChatOpenAI subclasses BaseComponent
node2: LCAzureChatOpenAI # this is also a node because LCAzureChatOpenAI subclasses BaseComponent
node3: RegexExtractor # this is also a node because RegexExtractor subclasses BaseComponent
def run(self, some_text: str):
@ -45,7 +45,7 @@ class FancyPipeline(BaseComponent):
Then this component can be used as follow:
```python
llm = AzureChatOpenAI(endpoint="some-endpoint")
llm = LCAzureChatOpenAI(endpoint="some-endpoint")
extractor = RegexExtractor(pattern=["yes", "Yes"])
component = FancyPipeline(

View File

@ -193,7 +193,8 @@ information panel.
You can access users' collections of LLMs and embedding models with:
```python
from ktem.components import llms, embeddings
from ktem.components import embeddings
from ktem.llms.manager import llms
llm = llms.get_default()
@ -206,12 +207,12 @@ models they want to use through the settings.
```python
@classmethod
def get_user_settings(cls) -> dict:
from ktem.components import llms
from ktem.llms.manager import llms
return {
"citation_llm": {
"name": "LLM for citation",
"value": llms.get_lowest_cost_name(),
"value": llms.get_default(),
"component: "dropdown",
"choices": list(llms.options().keys()),
},

View File

@ -52,7 +52,7 @@ class BaseComponent(Function):
def stream(self, *args, **kwargs) -> Iterator[Document] | None:
...
async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None:
...
@abstractmethod

View File

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Optional, TypeVar
from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar
from langchain.schema.messages import AIMessage as LCAIMessage
from langchain.schema.messages import HumanMessage as LCHumanMessage
@ -10,6 +10,9 @@ from llama_index.schema import Document as BaseDocument
if TYPE_CHECKING:
from haystack.schema import Document as HaystackDocument
from openai.types.chat.chat_completion_message_param import (
ChatCompletionMessageParam,
)
IO_Type = TypeVar("IO_Type", "Document", str)
SAMPLE_TEXT = "A sample Document from kotaemon"
@ -26,10 +29,15 @@ class Document(BaseDocument):
Attributes:
content: raw content of the document, can be anything
source: id of the source of the Document. Optional.
channel: the channel to show the document. Optional. One of:
- chat: show in chat message
- info: show in information panel
- debug: show in debug panel
"""
content: Any
content: Any = None
source: Optional[str] = None
channel: Optional[Literal["chat", "info", "debug"]] = None
def __init__(self, content: Optional[Any] = None, *args, **kwargs):
if content is None:
@ -87,17 +95,23 @@ class BaseMessage(Document):
def __add__(self, other: Any):
raise NotImplementedError
def to_openai_format(self) -> "ChatCompletionMessageParam":
raise NotImplementedError
class SystemMessage(BaseMessage, LCSystemMessage):
pass
def to_openai_format(self) -> "ChatCompletionMessageParam":
return {"role": "system", "content": self.content}
class AIMessage(BaseMessage, LCAIMessage):
pass
def to_openai_format(self) -> "ChatCompletionMessageParam":
return {"role": "assistant", "content": self.content}
class HumanMessage(BaseMessage, LCHumanMessage):
pass
def to_openai_format(self) -> "ChatCompletionMessageParam":
return {"role": "user", "content": self.content}
class RetrievedDocument(Document):
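
For reference, a quick sanity check of the new `to_openai_format` helpers (a minimal sketch using the message classes above):

```python
from kotaemon.base import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant"),
    HumanMessage(content="Hello"),
]
# each message now renders itself into the dict shape the openai client expects
print([m.to_openai_format() for m in messages])
# [{'role': 'system', 'content': 'You are a helpful assistant'},
#  {'role': 'user', 'content': 'Hello'}]
```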

View File

@ -1,7 +1,7 @@
import os
from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument
from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate
from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate
from .citation import CitationPipeline
@ -13,7 +13,7 @@ class CitationQAPipeline(BaseComponent):
'Answer the following question: "{question}". '
"The context is: \n{context}\nAnswer: "
)
llm: BaseLLM = AzureChatOpenAI.withx(
llm: BaseLLM = LCAzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
openai_api_version="2023-07-01-preview",

View File

@ -2,7 +2,15 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes
from .base import BaseLLM
from .branching import GatedBranchingPipeline, SimpleBranchingPipeline
from .chats import AzureChatOpenAI, ChatLLM, ChatOpenAI, EndpointChatLLM, LlamaCppChat
from .chats import (
AzureChatOpenAI,
ChatLLM,
ChatOpenAI,
EndpointChatLLM,
LCAzureChatOpenAI,
LCChatOpenAI,
LlamaCppChat,
)
from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI
from .cot import ManualSequentialChainOfThought, Thought
from .linear import GatedLinearPipeline, SimpleLinearPipeline
@ -17,8 +25,10 @@ __all__ = [
"HumanMessage",
"AIMessage",
"SystemMessage",
"ChatOpenAI",
"AzureChatOpenAI",
"ChatOpenAI",
"LCAzureChatOpenAI",
"LCChatOpenAI",
"LlamaCppChat",
# completion-specific components
"LLM",

View File

@ -18,5 +18,8 @@ class BaseLLM(BaseComponent):
def stream(self, *args, **kwargs) -> Iterator[LLMInterface]:
raise NotImplementedError
async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]:
raise NotImplementedError
def run(self, *args, **kwargs):
return self.invoke(*args, **kwargs)

View File

@ -15,7 +15,7 @@ class SimpleBranchingPipeline(BaseComponent):
Example:
```python
from kotaemon.llms import (
AzureChatOpenAI,
LCAzureChatOpenAI,
BasePromptComponent,
GatedLinearPipeline,
)
@ -25,7 +25,7 @@ class SimpleBranchingPipeline(BaseComponent):
return x
pipeline = SimpleBranchingPipeline()
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
@ -92,7 +92,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
Example:
```python
from kotaemon.llms import (
AzureChatOpenAI,
LCAzureChatOpenAI,
BasePromptComponent,
GatedLinearPipeline,
)
@ -102,7 +102,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
return x
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
@ -157,7 +157,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline):
if __name__ == "__main__":
import dotenv
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.llms import BasePromptComponent, LCAzureChatOpenAI
from kotaemon.parsers import RegexExtractor
def identity(x):
@ -166,7 +166,7 @@ if __name__ == "__main__":
secrets = dotenv.dotenv_values(".env")
pipeline = GatedBranchingPipeline()
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base=secrets.get("OPENAI_API_BASE", ""),
openai_api_key=secrets.get("OPENAI_API_KEY", ""),
openai_api_version=secrets.get("OPENAI_API_VERSION", ""),

View File

@ -1,13 +1,17 @@
from .base import ChatLLM
from .endpoint_based import EndpointChatLLM
from .langchain_based import AzureChatOpenAI, ChatOpenAI, LCChatMixin
from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI
from .llamacpp import LlamaCppChat
from .openai import AzureChatOpenAI, ChatOpenAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
"ChatLLM",
"EndpointChatLLM",
"ChatOpenAI",
"AzureChatOpenAI",
"LCChatOpenAI",
"LCAzureChatOpenAI",
"LCChatMixin",
"LlamaCppChat",
]

View File

@ -5,6 +5,7 @@ from kotaemon.base import (
BaseMessage,
HumanMessage,
LLMInterface,
Param,
SystemMessage,
)
@ -20,7 +21,9 @@ class EndpointChatLLM(ChatLLM):
endpoint_url (str): The url of a OpenAI API compatible endpoint.
"""
endpoint_url: str
endpoint_url: str = Param(
help="URL of the OpenAI API compatible endpoint", required=True
)
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs

View File

@ -165,7 +165,7 @@ class LCChatMixin:
raise ValueError(f"Invalid param {path}")
class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
class LCChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
def __init__(
self,
openai_api_base: str | None = None,
@ -193,7 +193,7 @@ class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
return ChatOpenAI
class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore
def __init__(
self,
azure_endpoint: str | None = None,

View File

@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Iterator, Optional, cast
from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param
@ -12,13 +12,32 @@ if TYPE_CHECKING:
class LlamaCppChat(ChatLLM):
"""Wrapper around the llama-cpp-python's Llama model"""
model_path: Optional[str] = None
chat_format: Optional[str] = None
lora_base: Optional[str] = None
n_ctx: int = 512
n_gpu_layers: int = 0
use_mmap: bool = True
vocab_only: bool = False
model_path: str = Param(
help="Path to the model file. This is required to load the model.",
required=True,
)
chat_format: str = Param(
help=(
"Chat format to use. Please refer to llama_cpp.llama_chat_format for a "
"list of supported formats. If blank, the chat format will be auto-"
"inferred."
),
required=True,
)
lora_base: Optional[str] = Param(None, help="Path to the base Lora model")
n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model")
n_gpu_layers: Optional[int] = Param(
0,
help=("Number of layers to offload to GPU. If -1, all layers are offloaded"),
)
use_mmap: Optional[bool] = Param(
True,
help="Whether to memory-map the model file (llama.cpp use_mmap)",
)
vocab_only: Optional[bool] = Param(
False,
help=("If True, only the vocabulary is loaded. This is useful for debugging."),
)
_role_mapper: dict[str, str] = {
"human": "user",
@ -60,9 +79,9 @@ class LlamaCppChat(ChatLLM):
vocab_only=self.vocab_only,
)
def run(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
def prepare_message(
self, messages: str | BaseMessage | list[BaseMessage]
) -> list[dict]:
input_: list[BaseMessage] = []
if isinstance(messages, str):
@ -72,11 +91,19 @@ class LlamaCppChat(ChatLLM):
else:
input_ = messages
pred: "CCCR" = self.client_object.create_chat_completion(
messages=[
output_ = [
{"role": self._role_mapper[each.type], "content": each.content}
for each in input_
], # type: ignore
]
return output_
def invoke(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> LLMInterface:
pred: "CCCR" = self.client_object.create_chat_completion(
messages=self.prepare_message(messages),
stream=False,
)
@ -91,3 +118,19 @@ class LlamaCppChat(ChatLLM):
total_tokens=pred["usage"]["total_tokens"],
prompt_tokens=pred["usage"]["prompt_tokens"],
)
def stream(
self, messages: str | BaseMessage | list[BaseMessage], **kwargs
) -> Iterator[LLMInterface]:
pred = self.client_object.create_chat_completion(
messages=self.prepare_message(messages),
stream=True,
)
for chunk in pred:
if not chunk["choices"]:
continue
if "content" not in chunk["choices"][0]["delta"]:
continue
yield LLMInterface(content=chunk["choices"][0]["delta"]["content"])
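
For illustration, the refactored `LlamaCppChat` can be driven like this (a sketch; the model path and chat format below are placeholders, not values from this diff):

```python
from kotaemon.llms import LlamaCppChat

llm = LlamaCppChat(
    model_path="/path/to/model.gguf",  # placeholder path
    chat_format="llama-2",             # must match the model's chat template
    n_ctx=2048,
    n_gpu_layers=-1,                   # offload all layers when a GPU is available
)

# blocking call: run() delegates to the new invoke()
print(llm("Hello, how are you?").content)

# token-by-token output via the new stream() method
for chunk in llm.stream("Hello, how are you?"):
    print(chunk.content, end="", flush=True)
```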

View File

@ -0,0 +1,356 @@
from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional
from theflow.utils.modules import import_dotted_string
from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param
from .base import ChatLLM
if TYPE_CHECKING:
from openai.types.chat.chat_completion_message_param import (
ChatCompletionMessageParam,
)
class BaseChatOpenAI(ChatLLM):
"""Base interface for OpenAI chat model, using the openai library
This class exposes the parameters in resources.Chat. To subclass this class:
- Implement the `prepare_client` method to return the OpenAI client
- Implement the `openai_response` method to return the OpenAI response
- Implement the params related to the OpenAI client
"""
_dependencies = ["openai"]
_capabilities = ["chat", "text"] # consider as mixin
api_key: str = Param(help="API key", required=True)
timeout: Optional[float] = Param(None, help="Timeout for the API request")
max_retries: Optional[int] = Param(
None, help="Maximum number of retries for the API request"
)
temperature: Optional[float] = Param(
None,
help=(
"Number between 0 and 2 that controls the randomness of the generated "
"tokens. Lower values make the model more deterministic, while higher "
"values make the model more random."
),
)
max_tokens: Optional[int] = Param(
None,
help=(
"Maximum number of tokens to generate. The total length of input tokens "
"and generated tokens is limited by the model's context length."
),
)
n: int = Param(
1,
help=(
"Number of completions to generate. The API will generate n completion "
"for each prompt."
),
)
stop: Optional[str | list[str]] = Param(
None,
help=(
"Stop sequence. If a stop sequence is detected, generation will stop "
"at that point. If not specified, generation will continue until the "
"maximum token length is reached."
),
)
frequency_penalty: Optional[float] = Param(
None,
help=(
"Number between -2.0 and 2.0. Positive values penalize new tokens "
"based on their existing frequency in the text so far, decrearsing the "
"model's likelihood of repeating the same text."
),
)
presence_penalty: Optional[float] = Param(
None,
help=(
"Number between -2.0 and 2.0. Positive values penalize new tokens "
"based on their existing presence in the text so far, decrearsing the "
"model's likelihood of repeating the same text."
),
)
tool_choice: Optional[str] = Param(
None,
help=(
"Choice of tool to use for the completion. Available choices are: "
"auto, default."
),
)
tools: Optional[list[str]] = Param(
None,
help="List of tools to use for the completion.",
)
logprobs: Optional[bool] = Param(
None,
help=(
"Include log probabilities on the logprobs most likely tokens, "
"as well as the chosen token."
),
)
logit_bias: Optional[dict] = Param(
None,
help=(
"Dictionary of logit bias values to add to the logits of the tokens "
"in the vocabulary."
),
)
top_logprobs: Optional[int] = Param(
None,
help=(
"An integer between 0 and 5 specifying the number of most likely tokens "
"to return at each token position, each with an associated log "
"probability. `logprobs` must also be set to `true` if this parameter "
"is used."
),
)
top_p: Optional[float] = Param(
None,
help=(
"An alternative to sampling with temperature, called nucleus sampling, "
"where the model considers the results of the token with top_p "
"probability mass. So 0.1 means that only the tokens comprising the "
"top 10% probability mass are considered."
),
)
@Param.auto(depends_on=["max_retries"])
def max_retries_(self):
if self.max_retries is None:
from openai._constants import DEFAULT_MAX_RETRIES
return DEFAULT_MAX_RETRIES
return self.max_retries
def prepare_message(
self, messages: str | BaseMessage | list[BaseMessage]
) -> list["ChatCompletionMessageParam"]:
"""Prepare the message into OpenAI format
Returns:
list[dict]: List of messages in OpenAI format
"""
input_: list[BaseMessage] = []
output_: list["ChatCompletionMessageParam"] = []
if isinstance(messages, str):
input_ = [HumanMessage(content=messages)]
elif isinstance(messages, BaseMessage):
input_ = [messages]
else:
input_ = messages
for message in input_:
output_.append(message.to_openai_format())
return output_
def prepare_client(self, async_version: bool = False):
"""Get the OpenAI client
Args:
async_version (bool): Whether to get the async version of the client
"""
raise NotImplementedError
def openai_response(self, client, **kwargs):
"""Get the openai response"""
raise NotImplementedError
def invoke(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
) -> LLMInterface:
client = self.prepare_client(async_version=False)
input_messages = self.prepare_message(messages)
resp = self.openai_response(
client, messages=input_messages, stream=False, **kwargs
).dict()
output = LLMInterface(
candidates=[_["message"]["content"] for _ in resp["choices"]],
content=resp["choices"][0]["message"]["content"],
total_tokens=resp["usage"]["total_tokens"],
prompt_tokens=resp["usage"]["prompt_tokens"],
completion_tokens=resp["usage"]["completion_tokens"],
messages=[
AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
],
)
return output
async def ainvoke(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
) -> LLMInterface:
client = self.prepare_client(async_version=True)
input_messages = self.prepare_message(messages)
resp = await self.openai_response(
client, messages=input_messages, stream=False, **kwargs
)
resp = resp.dict()  # call .dict() only after the coroutine has resolved
output = LLMInterface(
candidates=[_["message"]["content"] for _ in resp["choices"]],
content=resp["choices"][0]["message"]["content"],
total_tokens=resp["usage"]["total_tokens"],
prompt_tokens=resp["usage"]["prompt_tokens"],
completion_tokens=resp["usage"]["completion_tokens"],
messages=[
AIMessage(content=_["message"]["content"]) for _ in resp["choices"]
],
)
return output
def stream(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
) -> Iterator[LLMInterface]:
client = self.prepare_client(async_version=False)
input_messages = self.prepare_message(messages)
resp = self.openai_response(
client, messages=input_messages, stream=True, **kwargs
)
for chunk in resp:
if not chunk.choices:
continue
if chunk.choices[0].delta.content is not None:
yield LLMInterface(content=chunk.choices[0].delta.content)
async def astream(
self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs
) -> AsyncGenerator[LLMInterface, None]:
client = self.prepare_client(async_version=True)
input_messages = self.prepare_message(messages)
resp = await self.openai_response(
client, messages=input_messages, stream=True, **kwargs
)
async for chunk in resp:
if not chunk.choices:
continue
if chunk.choices[0].delta.content is not None:
yield LLMInterface(content=chunk.choices[0].delta.content)
class ChatOpenAI(BaseChatOpenAI):
"""OpenAI chat model"""
base_url: Optional[str] = Param(None, help="OpenAI base URL")
organization: Optional[str] = Param(None, help="OpenAI organization")
model: str = Param(help="OpenAI model", required=True)
def prepare_client(self, async_version: bool = False):
"""Get the OpenAI client
Args:
async_version (bool): Whether to get the async version of the client
"""
params = {
"api_key": self.api_key,
"organization": self.organization,
"base_url": self.base_url,
"timeout": self.timeout,
"max_retries": self.max_retries_,
}
if async_version:
from openai import AsyncOpenAI
return AsyncOpenAI(**params)
from openai import OpenAI
return OpenAI(**params)
def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = {
"model": self.model,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": self.n,
"stop": self.stop,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"tool_choice": self.tool_choice,
"tools": self.tools,
"logprobs": self.logprobs,
"logit_bias": self.logit_bias,
"top_logprobs": self.top_logprobs,
"top_p": self.top_p,
}
params.update(kwargs)
return client.chat.completions.create(**params)
class AzureChatOpenAI(BaseChatOpenAI):
"""OpenAI chat model provided by Microsoft Azure"""
azure_endpoint: str = Param(
help=(
"HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, "
"azure_deployment, and api_version parameters are used to construct "
"the full URL for the Azure OpenAI model."
)
)
azure_deployment: str = Param(help="Azure deployment name", required=True)
api_version: str = Param(help="Azure model version", required=True)
azure_ad_token: Optional[str] = Param(None, help="Azure AD token")
azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider")
@Param.auto(depends_on=["azure_ad_token_provider"])
def azure_ad_token_provider_(self):
if isinstance(self.azure_ad_token_provider, str):
return import_dotted_string(self.azure_ad_token_provider, safe=False)
def prepare_client(self, async_version: bool = False):
"""Get the OpenAI client
Args:
async_version (bool): Whether to get the async version of the client
"""
params = {
"azure_endpoint": self.azure_endpoint,
"api_version": self.api_version,
"api_key": self.api_key,
"azure_ad_token": self.azure_ad_token,
"azure_ad_token_provider": self.azure_ad_token_provider_,
"timeout": self.timeout,
"max_retries": self.max_retries_,
}
if async_version:
from openai import AsyncAzureOpenAI
return AsyncAzureOpenAI(**params)
from openai import AzureOpenAI
return AzureOpenAI(**params)
def openai_response(self, client, **kwargs):
"""Get the openai response"""
params = {
"model": self.azure_deployment,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": self.n,
"stop": self.stop,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"tool_choice": self.tool_choice,
"tools": self.tools,
"logprobs": self.logprobs,
"logit_bias": self.logit_bias,
"top_logprobs": self.top_logprobs,
"top_p": self.top_p,
}
params.update(kwargs)
return client.chat.completions.create(**params)
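
The two vanilla wrappers then expose the same `invoke`/`stream` interface, for example (a sketch; keys, model names, and deployment names are placeholders):

```python
from kotaemon.llms import AzureChatOpenAI, ChatOpenAI

llm = ChatOpenAI(
    api_key="sk-...",          # placeholder
    model="gpt-3.5-turbo",
    temperature=0,
)

azure_llm = AzureChatOpenAI(
    api_key="...",             # placeholder
    azure_endpoint="https://my-resource.openai.azure.com/",
    azure_deployment="gpt-35-turbo",
    api_version="2024-02-15-preview",
    temperature=0,
)

print(llm("Hello!").content)
for chunk in azure_llm.stream("Hello!"):
    print(chunk.content, end="", flush=True)
```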

View File

@ -5,7 +5,7 @@ from theflow import Function, Node, Param
from kotaemon.base import BaseComponent, Document
from .chats import AzureChatOpenAI
from .chats import LCAzureChatOpenAI
from .completions import LLM
from .prompts import BasePromptComponent
@ -25,7 +25,7 @@ class Thought(BaseComponent):
>> from kotaemon.pipelines.cot import Thought
>> thought = Thought(
prompt="How to {action} {object}?",
llm=AzureChatOpenAI(...),
llm=LCAzureChatOpenAI(...),
post_process=lambda string: {"tutorial": string},
)
>> output = thought(action="install", object="python")
@ -42,7 +42,7 @@ class Thought(BaseComponent):
This `Thought` allows chaining sequentially with the + operator. For example:
```python
>> llm = AzureChatOpenAI(...)
>> llm = LCAzureChatOpenAI(...)
>> thought1 = Thought(
prompt="Word {word} in {language} is ",
llm=llm,
@ -73,7 +73,7 @@ class Thought(BaseComponent):
" component is executed"
)
)
llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt")
llm: LLM = Node(LCAzureChatOpenAI, help="The LLM model to execute the input prompt")
post_process: Function = Node(
help=(
"The function post-processor that post-processes LLM output prediction ."
@ -117,7 +117,7 @@ class ManualSequentialChainOfThought(BaseComponent):
```pycon
>>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought
>>> llm = AzureChatOpenAI(...)
>>> llm = LCAzureChatOpenAI(...)
>>> thought1 = Thought(
>>> prompt="Word {word} in {language} is ",
>>> post_process=lambda string: {"translated": string},

View File

@ -22,12 +22,12 @@ class SimpleLinearPipeline(BaseComponent):
Example Usage:
```python
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent
def identity(x):
return x
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",
@ -89,13 +89,13 @@ class GatedLinearPipeline(SimpleLinearPipeline):
Usage:
```{.py3 title="Example Usage"}
from kotaemon.llms import AzureChatOpenAI, BasePromptComponent
from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent
from kotaemon.parsers import RegexExtractor
def identity(x):
return x
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base="your openai api base",
openai_api_key="your openai api key",
openai_api_version="your openai api version",

View File

@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"]
# metadata and dependencies
[project]
name = "kotaemon"
version = "0.3.8"
version = "0.3.9"
requires-python = ">= 3.10"
description = "Kotaemon core library for AI development."
dependencies = [

View File

@ -13,7 +13,7 @@ from kotaemon.agents import (
RewooAgent,
WikipediaTool,
)
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!"
REWOO_VALID_PLAN = (
@ -112,7 +112,7 @@ _openai_chat_completion_responses_react_langchain_tool = [
@pytest.fixture
def llm():
return AzureChatOpenAI(
return LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",

View File

@ -4,10 +4,10 @@ import pytest
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.llms import (
AzureChatOpenAI,
BasePromptComponent,
GatedBranchingPipeline,
GatedLinearPipeline,
LCAzureChatOpenAI,
SimpleBranchingPipeline,
SimpleLinearPipeline,
)
@ -40,7 +40,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
@pytest.fixture
def mock_llm():
return AzureChatOpenAI(
return LCAzureChatOpenAI(
azure_endpoint="OPENAI_API_BASE",
openai_api_key="OPENAI_API_KEY",
openai_api_version="OPENAI_API_VERSION",

View File

@ -2,7 +2,7 @@ from unittest.mock import patch
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought
_openai_chat_completion_response = [
@ -38,7 +38,7 @@ _openai_chat_completion_response = [
side_effect=_openai_chat_completion_response,
)
def test_cot_plus_operator(openai_completion):
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
@ -70,7 +70,7 @@ def test_cot_plus_operator(openai_completion):
side_effect=_openai_chat_completion_response,
)
def test_cot_manual(openai_completion):
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",
@ -100,7 +100,7 @@ def test_cot_manual(openai_completion):
side_effect=_openai_chat_completion_response,
)
def test_cot_with_termination_callback(openai_completion):
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",

View File

@ -4,7 +4,7 @@ from unittest.mock import patch
import pytest
from kotaemon.base.schema import AIMessage, HumanMessage, LLMInterface, SystemMessage
from kotaemon.llms import AzureChatOpenAI, LlamaCppChat
from kotaemon.llms import LCAzureChatOpenAI, LlamaCppChat
try:
from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC
@ -43,7 +43,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj(
side_effect=lambda *args, **kwargs: _openai_chat_completion_response,
)
def test_azureopenai_model(openai_completion):
model = AzureChatOpenAI(
model = LCAzureChatOpenAI(
azure_endpoint="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",

View File

@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.base import Document
from kotaemon.indices.rankings import LLMReranking
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
_openai_chat_completion_responses = [
ChatCompletion.parse_obj(
@ -41,7 +41,7 @@ _openai_chat_completion_responses = [
@pytest.fixture
def llm():
return AzureChatOpenAI(
return LCAzureChatOpenAI(
azure_endpoint="https://dummy.openai.azure.com/",
openai_api_key="dummy",
openai_api_version="2023-03-15-preview",

View File

@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
):
if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""):
KH_LLMS["azure"] = {
"def": {
"spec": {
"__type__": "kotaemon.llms.AzureChatOpenAI",
"temperature": 0,
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_key": config("AZURE_OPENAI_API_KEY", default=""),
"api_version": config("OPENAI_API_VERSION", default="")
or "2024-02-15-preview",
"deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
"request_timeout": 10,
"stream": False,
"azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""),
"timeout": 20,
},
"default": False,
"accuracy": 5,
@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config(
}
if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""):
KH_EMBEDDINGS["azure"] = {
"def": {
"spec": {
"__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings",
"azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""),
"openai_api_key": config("AZURE_OPENAI_API_KEY", default=""),
@ -164,5 +163,11 @@ KH_INDICES = [
"name": "File",
"config": {},
"index_type": "ktem.index.file.FileIndex",
}
},
{
"id": 2,
"name": "Sample",
"config": {},
"index_type": "ktem.index.file.FileIndex",
},
]
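
With the new `spec` layout, registering a plain (non-Azure) OpenAI model in `flowsettings.py` would look roughly like the sketch below. The `OPENAI_API_KEY`/`OPENAI_CHAT_MODEL` config names are assumptions for illustration, not part of this diff; `config` is the same helper already used throughout this file.

```python
# Hypothetical extra entry using the new "spec" layout (not part of this commit).
if config("OPENAI_API_KEY", default=""):
    KH_LLMS["openai"] = {
        "spec": {
            "__type__": "kotaemon.llms.ChatOpenAI",
            "api_key": config("OPENAI_API_KEY", default=""),
            "model": config("OPENAI_CHAT_MODEL", default="gpt-3.5-turbo"),
            "temperature": 0,
            "timeout": 20,
        },
        "default": False,
    }
```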

View File

@ -3,6 +3,7 @@
import logging
from functools import cache
from pathlib import Path
from typing import Optional
from theflow.settings import settings
from theflow.utils.modules import deserialize
@ -48,7 +49,7 @@ class ModelPool:
self._default: list[str] = []
for name, model in conf.items():
self._models[name] = deserialize(model["def"], safe=False)
self._models[name] = deserialize(model["spec"], safe=False)
if model.get("default", False):
self._default.append(name)
@ -58,11 +59,27 @@ class ModelPool:
self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf"))))
def __getitem__(self, key: str) -> BaseComponent:
"""Get model by name"""
return self._models[key]
def __setitem__(self, key: str, value: BaseComponent):
"""Set model by name"""
self._models[key] = value
def __delitem__(self, key: str):
"""Delete model by name"""
del self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
def get(
self, key: str, default: Optional[BaseComponent] = None
) -> Optional[BaseComponent]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS)
embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS)
reasonings: dict = {}
tools = ModelPool("Tools", {})
indices = ModelPool("Indices", {})

View File

@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever):
@classmethod
def get_user_settings(cls) -> dict:
from ktem.components import llms
from ktem.llms.manager import llms
try:
reranking_llm = llms.get_lowest_cost_name()
reranking_llm = llms.get_default_name()
reranking_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)

View File

36
libs/ktem/ktem/llms/db.py Normal file
View File

@ -0,0 +1,36 @@
from typing import Type
from ktem.db.engine import engine
from sqlalchemy import JSON, Boolean, Column, String
from sqlalchemy.orm import DeclarativeBase
from theflow.settings import settings as flowsettings
from theflow.utils.modules import import_dotted_string
class Base(DeclarativeBase):
pass
class BaseLLMTable(Base):
"""Base table to store language model"""
__abstract__ = True
name = Column(String, primary_key=True, unique=True)
spec = Column(JSON, default={})
default = Column(Boolean, default=False)
_base_llm: Type[BaseLLMTable] = (
import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False)
if hasattr(flowsettings, "KH_TABLE_LLM")
else BaseLLMTable
)
class LLMTable(_base_llm): # type: ignore
__tablename__ = "llm_table"
if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False):
LLMTable.metadata.create_all(engine)

View File

@ -0,0 +1,191 @@
from typing import Optional, Type
from sqlalchemy import select
from sqlalchemy.orm import Session
from theflow.settings import settings as flowsettings
from theflow.utils.modules import deserialize
from kotaemon.base import BaseComponent
from .db import LLMTable, engine
class LLMManager:
"""Represent a pool of models"""
def __init__(self):
self._models: dict[str, BaseComponent] = {}
self._info: dict[str, dict] = {}
self._default: str = ""
self._vendors: list[Type] = []
if hasattr(flowsettings, "KH_LLMS"):
for name, model in flowsettings.KH_LLMS.items():
with Session(engine) as session:
stmt = select(LLMTable).where(LLMTable.name == name)
result = session.execute(stmt)
if not result.first():
item = LLMTable(
name=name,
spec=model["spec"],
default=model.get("default", False),
)
session.add(item)
session.commit()
self.load()
self.load_vendors()
def load(self):
"""Load the model pool from database"""
self._models, self._info, self._default = {}, {}, ""
with Session(engine) as session:
stmt = select(LLMTable)
items = session.execute(stmt)
for (item,) in items:
self._models[item.name] = deserialize(item.spec, safe=False)
self._info[item.name] = {
"name": item.name,
"spec": item.spec,
"default": item.default,
}
if item.default:
self._default = item.name
def load_vendors(self):
from kotaemon.llms import (
AzureChatOpenAI,
ChatOpenAI,
EndpointChatLLM,
LlamaCppChat,
)
self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM]
def __getitem__(self, key: str) -> BaseComponent:
"""Get model by name"""
return self._models[key]
def __contains__(self, key: str) -> bool:
"""Check if model exists"""
return key in self._models
def get(
self, key: str, default: Optional[BaseComponent] = None
) -> Optional[BaseComponent]:
"""Get model by name with default value"""
return self._models.get(key, default)
def settings(self) -> dict:
"""Present model pools option for gradio"""
return {
"label": "LLM",
"choices": list(self._models.keys()),
"value": self.get_default_name(),
}
def options(self) -> dict:
"""Present a dict of models"""
return self._models
def get_random_name(self) -> str:
"""Get the name of random model
Returns:
str: random model name in the pool
"""
import random
if not self._models:
raise ValueError("No models in pool")
return random.choice(list(self._models.keys()))
def get_default_name(self) -> str:
"""Get the name of default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
str: model name
"""
if not self._models:
raise ValueError("No models in pool")
if not self._default:
return self.get_random_name()
return self._default
def get_random(self) -> BaseComponent:
"""Get random model"""
return self._models[self.get_random_name()]
def get_default(self) -> BaseComponent:
"""Get default model
In case there is no default model, choose random model from pool. In
case there are multiple default models, choose random from them.
Returns:
BaseComponent: model
"""
return self._models[self.get_default_name()]
def info(self) -> dict:
"""List all models"""
return self._info
def add(self, name: str, spec: dict, default: bool):
"""Add a new model to the pool"""
try:
with Session(engine) as session:
item = LLMTable(name=name, spec=spec, default=default)
session.add(item)
session.commit()
except Exception as e:
raise ValueError(f"Failed to add model {name}: {e}")
self.load()
def delete(self, name: str):
"""Delete a model from the pool"""
try:
with Session(engine) as session:
item = session.query(LLMTable).filter_by(name=name).first()
session.delete(item)
session.commit()
except Exception as e:
raise ValueError(f"Failed to delete model {name}: {e}")
self.load()
def update(self, name: str, spec: dict, default: bool):
"""Update a model in the pool"""
try:
with Session(engine) as session:
if default:
# turn all models to non-default
session.query(LLMTable).update({"default": False})
session.commit()
item = session.query(LLMTable).filter_by(name=name).first()
if not item:
raise ValueError(f"Model {name} not found")
item.spec = spec
item.default = default
session.commit()
except Exception as e:
raise ValueError(f"Failed to update model {name}: {e}")
self.load()
def vendors(self) -> dict:
"""Return list of vendors"""
return {vendor.__qualname__: vendor for vendor in self._vendors}
llms = LLMManager()
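
For reference, the manager can be exercised programmatically the same way the new admin UI does (a sketch; the endpoint, key, and deployment values are placeholders):

```python
from ktem.llms.manager import llms

# register a model; this is the same path the "Add" tab of the UI goes through
llms.add(
    "my-azure-gpt4",
    spec={
        "__type__": "kotaemon.llms.AzureChatOpenAI",
        "azure_endpoint": "https://my-resource.openai.azure.com/",  # placeholder
        "api_key": "...",                                           # placeholder
        "azure_deployment": "gpt-4",
        "api_version": "2024-02-15-preview",
    },
    default=True,
)

print(list(llms.options().keys()))  # names of all registered LLMs
default_llm = llms.get_default()    # resolves to a ready-to-use kotaemon LLM
```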

318
libs/ktem/ktem/llms/ui.py Normal file
View File

@ -0,0 +1,318 @@
from copy import deepcopy
import gradio as gr
import pandas as pd
import yaml
from ktem.app import BasePage
from .manager import llms
def format_description(cls):
params = cls.describe()["params"]
params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"]
for key, value in params.items():
if isinstance(value["auto_callback"], str):
continue
params_lines.append(f"| {key} | {value['type']} | {value['help']} |")
return f"{cls.__doc__}\n\n" + "\n".join(params_lines)
class LLMManagement(BasePage):
def __init__(self, app):
self._app = app
self.spec_desc_default = (
"# Spec description\n\nSelect an LLM to view the spec description."
)
self.on_building_ui()
def on_building_ui(self):
with gr.Tab(label="View"):
self.llm_list = gr.DataFrame(
headers=["name", "vendor", "default"],
interactive=False,
)
with gr.Column(visible=False) as self._selected_panel:
self.selected_llm_name = gr.Textbox(value="", visible=False)
with gr.Row():
with gr.Column():
self.edit_default = gr.Checkbox(
label="Set default",
info=(
"Set this LLM as default. If no default is set, a "
"random LLM will be used."
),
)
self.edit_spec = gr.Textbox(
label="Specification",
info="Specification of the LLM in YAML format",
lines=10,
)
with gr.Row(visible=False) as self._selected_panel_btn:
with gr.Column():
self.btn_edit_save = gr.Button("Save", min_width=10)
with gr.Column():
self.btn_delete = gr.Button("Delete", min_width=10)
with gr.Row():
self.btn_delete_yes = gr.Button(
"Confirm delete",
variant="primary",
visible=False,
min_width=10,
)
self.btn_delete_no = gr.Button(
"Cancel", visible=False, min_width=10
)
with gr.Column():
self.btn_close = gr.Button("Close", min_width=10)
with gr.Column():
self.edit_spec_desc = gr.Markdown("# Spec description")
with gr.Tab(label="Add"):
with gr.Row():
with gr.Column(scale=2):
self.name = gr.Textbox(
label="LLM name",
info=(
"Must be unique. The name will be used to identify the LLM."
),
)
self.llm_choices = gr.Dropdown(
label="LLM vendors",
info=(
"Choose the vendor for the LLM. Each vendor has different "
"specification."
),
)
self.spec = gr.Textbox(
label="Specification",
info="Specification of the LLM in YAML format",
)
self.default = gr.Checkbox(
label="Set default",
info=(
"Set this LLM as default. This default LLM will be used "
"by default across the application."
),
)
self.btn_new = gr.Button("Create LLM")
with gr.Column(scale=3):
self.spec_desc = gr.Markdown(self.spec_desc_default)
def _on_app_created(self):
"""Called when the app is created"""
self._app.app.load(
self.list_llms,
inputs=None,
outputs=[self.llm_list],
)
self._app.app.load(
lambda: gr.update(choices=list(llms.vendors().keys())),
outputs=[self.llm_choices],
)
def on_llm_vendor_change(self, vendor):
vendor = llms.vendors()[vendor]
required: dict = {}
desc = vendor.describe()
for key, value in desc["params"].items():
if value.get("required", False):
required[key] = None
return yaml.dump(required), format_description(vendor)
def on_register_events(self):
self.llm_choices.select(
self.on_llm_vendor_change,
inputs=[self.llm_choices],
outputs=[self.spec, self.spec_desc],
)
self.btn_new.click(
self.create_llm,
inputs=[self.name, self.llm_choices, self.spec, self.default],
outputs=None,
).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then(
lambda: ("", None, "", False, self.spec_desc_default),
outputs=[
self.name,
self.llm_choices,
self.spec,
self.default,
self.spec_desc,
],
)
self.llm_list.select(
self.select_llm,
inputs=self.llm_list,
outputs=[self.selected_llm_name],
show_progress="hidden",
)
self.selected_llm_name.change(
self.on_selected_llm_change,
inputs=[self.selected_llm_name],
outputs=[
self._selected_panel,
self._selected_panel_btn,
# delete section
self.btn_delete,
self.btn_delete_yes,
self.btn_delete_no,
# edit section
self.edit_spec,
self.edit_spec_desc,
self.edit_default,
],
show_progress="hidden",
)
self.btn_delete.click(
self.on_btn_delete_click,
inputs=None,
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_delete_yes.click(
self.delete_llm,
inputs=[self.selected_llm_name],
outputs=[self.selected_llm_name],
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
outputs=[self.llm_list],
)
self.btn_delete_no.click(
lambda: (
gr.update(visible=True),
gr.update(visible=False),
gr.update(visible=False),
),
inputs=None,
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no],
show_progress="hidden",
)
self.btn_edit_save.click(
self.save_llm,
inputs=[
self.selected_llm_name,
self.edit_default,
self.edit_spec,
],
show_progress="hidden",
).then(
self.list_llms,
inputs=None,
outputs=[self.llm_list],
)
self.btn_close.click(
lambda: "",
outputs=[self.selected_llm_name],
)
def create_llm(self, name, choices, spec, default):
try:
spec = yaml.safe_load(spec)
spec["__type__"] = (
llms.vendors()[choices].__module__
+ "."
+ llms.vendors()[choices].__qualname__
)
llms.add(name, spec=spec, default=default)
gr.Info(f"LLM {name} created successfully")
except Exception as e:
gr.Error(f"Failed to create LLM {name}: {e}")
def list_llms(self):
"""List the LLMs"""
items = []
for item in llms.info().values():
record = {}
record["name"] = item["name"]
record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1]
record["default"] = item["default"]
items.append(record)
if items:
llm_list = pd.DataFrame.from_records(items)
else:
llm_list = pd.DataFrame.from_records(
[{"name": "-", "vendor": "-", "default": "-"}]
)
return llm_list
def select_llm(self, llm_list, ev: gr.SelectData):
if ev.value == "-" and ev.index[0] == 0:
gr.Info("No LLM is loaded. Please add LLM first")
return ""
if not ev.selected:
return ""
return llm_list["name"][ev.index[0]]
def on_selected_llm_change(self, selected_llm_name):
if selected_llm_name == "":
_selected_panel = gr.update(visible=False)
_selected_panel_btn = gr.update(visible=False)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
edit_spec = gr.update(value="")
edit_spec_desc = gr.update(value="")
edit_default = gr.update(value=False)
else:
_selected_panel = gr.update(visible=True)
_selected_panel_btn = gr.update(visible=True)
btn_delete = gr.update(visible=True)
btn_delete_yes = gr.update(visible=False)
btn_delete_no = gr.update(visible=False)
info = deepcopy(llms.info()[selected_llm_name])
vendor_str = info["spec"].pop("__type__", "-").split(".")[-1]
vendor = llms.vendors()[vendor_str]
edit_spec = yaml.dump(info["spec"])
edit_spec_desc = format_description(vendor)
edit_default = info["default"]
return (
_selected_panel,
_selected_panel_btn,
btn_delete,
btn_delete_yes,
btn_delete_no,
edit_spec,
edit_spec_desc,
edit_default,
)
def on_btn_delete_click(self):
btn_delete = gr.update(visible=False)
btn_delete_yes = gr.update(visible=True)
btn_delete_no = gr.update(visible=True)
return btn_delete, btn_delete_yes, btn_delete_no
def save_llm(self, selected_llm_name, default, spec):
try:
spec = yaml.safe_load(spec)
spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"]
llms.update(selected_llm_name, spec=spec, default=default)
gr.Info(f"LLM {selected_llm_name} saved successfully")
except Exception as e:
gr.Error(f"Failed to save LLM {selected_llm_name}: {e}")
def delete_llm(self, selected_llm_name):
try:
llms.delete(selected_llm_name)
except Exception as e:
gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}")
return selected_llm_name
return ""

View File

@ -1,6 +1,7 @@
import gradio as gr
from ktem.app import BasePage
from ktem.db.models import User, engine
from ktem.llms.ui import LLMManagement
from sqlmodel import Session, select
from .user import UserManagement
@ -16,6 +17,9 @@ class AdminPage(BasePage):
with gr.Tab("User Management", visible=False) as self.user_management_tab:
self.user_management = UserManagement(self._app)
with gr.Tab("LLM Management") as self.llm_management_tab:
self.llm_management = LLMManagement(self._app)
def on_subscribe_public_events(self):
if self._app.f_user_management:
self._app.subscribe_event(

View File

@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine
from sqlmodel import Session, select
from theflow.settings import settings as flowsettings
from kotaemon.base import Document
from .chat_panel import ChatPanel
from .chat_suggestion import ChatSuggestion
from .common import STATE
@ -189,6 +191,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
@ -220,6 +223,7 @@ class ChatPage(BasePage):
self.chat_control.conversation_rn,
self.chat_panel.chatbot,
self.info_panel,
self.chat_state,
]
+ self._indices_input,
show_progress="hidden",
@ -392,7 +396,7 @@ class ChatPage(BasePage):
return pipeline, reasoning_state
async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Chat function"""
chat_input = chat_history[-1][0]
chat_history = chat_history[:-1]
@ -403,52 +407,43 @@ class ChatPage(BasePage):
pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds)
pipeline.set_output_queue(queue)
asyncio.create_task(pipeline(chat_input, conversation_id, chat_history))
text, refs = "", ""
len_ref = -1 # for logging purpose
msg_placeholder = getattr(
flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..."
)
print(msg_placeholder)
while True:
try:
response = queue.get_nowait()
except Exception:
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [
(chat_input, text or msg_placeholder)
], refs, state
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
len_ref = -1 # for logging purpose
for response in pipeline.stream(chat_input, conversation_id, chat_history):
if not isinstance(response, Document):
continue
if response is None:
queue.task_done()
print("Chat completed")
break
if response.channel is None:
continue
if "output" in response:
if response["output"] is None:
if response.channel == "chat":
if response.content is None:
text = ""
else:
text += response["output"]
text += response.content
if "evidence" in response:
if response["evidence"] is None:
if response.channel == "info":
if response.content is None:
refs = ""
else:
refs += response["evidence"]
refs += response.content
if len(refs) > len_ref:
print(f"Len refs: {len(refs)}")
len_ref = len(refs)
state[pipeline.get_info()["id"]] = reasoning_state["pipeline"]
yield chat_history + [(chat_input, text)], refs, state
yield chat_history + [(chat_input, text or msg_placeholder)], refs, state
async def regen_fn(
self, conversation_id, chat_history, settings, state, *selecteds
):
def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds):
"""Regen function"""
if not chat_history:
gr.Warning("Empty chat")
@ -456,12 +451,11 @@ class ChatPage(BasePage):
return
state["app"]["regen"] = True
async for chat, refs, state in self.chat_fn(
for chat, refs, state in self.chat_fn(
conversation_id, chat_history, settings, state, *selecteds
):
new_state = deepcopy(state)
new_state["app"]["regen"] = False
yield chat, refs, new_state
else:
state["app"]["regen"] = False
yield chat_history, "", state

View File

@ -4,10 +4,10 @@ import logging
import re
from collections import defaultdict
from functools import partial
from typing import Generator
import tiktoken
from ktem.components import llms
from theflow.settings import settings as flowsettings
from ktem.llms.manager import llms
from kotaemon.base import (
BaseComponent,
@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent):
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy())
vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
vlm_endpoint: str = ""
citation_pipeline: CitationPipeline = Node(
default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost())
default_callback=lambda _: CitationPipeline(llm=llms.get_default())
)
qa_template: str = DEFAULT_QA_TEXT_PROMPT
@ -297,8 +297,90 @@ class AnswerWithContextPipeline(BaseComponent):
return answer
def stream( # type: ignore
self, question: str, evidence: str, evidence_mode: int = 0, **kwargs
) -> Generator[Document, None, Document]:
"""Answer the question based on the evidence
def extract_evidence_images(self, evidence: str):
In addition to the question and the evidence, this method also takes into
account evidence_mode. The evidence_mode tells what kind of evidence it is.
The kind of evidence affects:
1. How the evidence is represented.
2. The prompt to generate the answer.
By default, the evidence_mode is 0, which means the evidence is plain text with
no particular semantic representation. The evidence_mode can be:
1. "table": There will be HTML markup telling that there is a table
within the evidence.
2. "chatbot": There will be HTML markup telling that there is a chatbot.
This chatbot is a scenario, extracted from an Excel file, where each
row corresponds to an interaction.
Args:
question: the original question posed by user
evidence: the text that contain relevant information to answer the question
(determined by retrieval pipeline)
evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot
"""
if evidence_mode == EVIDENCE_MODE_TEXT:
prompt_template = PromptTemplate(self.qa_template)
elif evidence_mode == EVIDENCE_MODE_TABLE:
prompt_template = PromptTemplate(self.qa_table_template)
elif evidence_mode == EVIDENCE_MODE_FIGURE:
prompt_template = PromptTemplate(self.qa_figure_template)
else:
prompt_template = PromptTemplate(self.qa_chatbot_template)
images = []
if evidence_mode == EVIDENCE_MODE_FIGURE:
# isolate image from evidence
evidence, images = self.extract_evidence_images(evidence)
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
else:
prompt = prompt_template.populate(
context=evidence,
question=question,
lang=self.lang,
)
output = ""
if evidence_mode == EVIDENCE_MODE_FIGURE:
for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768):
output += text
yield Document(channel="chat", content=text)
else:
messages = []
if self.system_prompt:
messages.append(SystemMessage(content=self.system_prompt))
messages.append(HumanMessage(content=prompt))
try:
# try streaming first
print("Trying LLM streaming")
for text in self.llm.stream(messages):
output += text.text
yield Document(channel="chat", content=text.text)
except NotImplementedError:
print("Streaming is not supported, falling back to normal processing")
output = self.llm(messages).text
yield Document(channel="chat", content=output)
# retrieve the citation
citation = None
if evidence and self.enable_citation:
citation = self.citation_pipeline.invoke(
context=evidence, question=question
)
answer = Document(text=output, metadata={"citation": citation})
return answer
def extract_evidence_images(self, evidence: str):
"""Util function to extract and isolate images from context/evidence"""
image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'"
matches = re.findall(image_pattern, evidence)
@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent):
lang: the language of the answer. Currently support English and Japanese
"""
llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost())
llm: ChatLLM = Node(default_callback=lambda _: llms.get_default())
rewrite_template: str = DEFAULT_REWRITE_PROMPT
lang: str = "English"
async def run(self, question: str) -> Document: # type: ignore
def run(self, question: str) -> Document: # type: ignore
prompt_template = PromptTemplate(self.rewrite_template)
prompt = prompt_template.populate(question=question, lang=self.lang)
messages = [
SystemMessage(content="You are a helpful assistant"),
HumanMessage(content=prompt),
]
output = ""
for text in self.llm(messages):
if "content" in text:
output += text[1]
self.report_output({"chat_input": text[1]})
break
await asyncio.sleep(0)
return Document(text=output)
return self.llm(messages)
class FullQAPipeline(BaseReasoning):
@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning):
rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx()
use_rewrite: bool = False
async def run( # type: ignore
async def ainvoke( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Document: # type: ignore
import markdown
@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning):
self.report_output(None)
return answer
def stream( # type: ignore
self, message: str, conv_id: str, history: list, **kwargs # type: ignore
) -> Generator[Document, None, Document]:
import markdown
docs = []
doc_ids = []
if self.use_rewrite:
message = self.rewrite_pipeline(question=message).text
for retriever in self.retrievers:
for doc in retriever(text=message):
if doc.doc_id not in doc_ids:
docs.append(doc)
doc_ids.append(doc.doc_id)
for doc in docs:
# TODO: a better approach to show the information
text = markdown.markdown(
doc.text, extensions=["markdown.extensions.tables"]
)
yield Document(
content=(
"<details open>"
f"<summary>{doc.metadata['file_name']}</summary>"
f"{text}"
"</details><br>"
),
channel="info",
)
evidence_mode, evidence = self.evidence_pipeline(docs).content
answer = yield from self.answering_pipeline.stream(
question=message,
history=history,
evidence=evidence,
evidence_mode=evidence_mode,
conv_id=conv_id,
**kwargs,
)
# prepare citation
spans = defaultdict(list)
if answer.metadata["citation"] is not None:
for fact_with_evidence in answer.metadata["citation"].answer:
for quote in fact_with_evidence.substring_quote:
for doc in docs:
start_idx = doc.text.find(quote)
if start_idx == -1:
continue
end_idx = start_idx + len(quote)
current_idx = start_idx
if "|" not in doc.text[start_idx:end_idx]:
spans[doc.doc_id].append(
{"start": start_idx, "end": end_idx}
)
else:
while doc.text[current_idx:end_idx].find("|") != -1:
match_idx = doc.text[current_idx:end_idx].find("|")
spans[doc.doc_id].append(
{
"start": current_idx,
"end": current_idx + match_idx,
}
)
current_idx += match_idx + 2
if current_idx > end_idx:
break
break
id2docs = {doc.doc_id: doc for doc in docs}
lack_evidence = True
not_detected = set(id2docs.keys()) - set(spans.keys())
yield Document(channel="info", content=None)
for id, ss in spans.items():
if not ss:
not_detected.add(id)
continue
ss = sorted(ss, key=lambda x: x["start"])
text = id2docs[id].text[: ss[0]["start"]]
for idx, span in enumerate(ss):
text += (
"<mark>" + id2docs[id].text[span["start"] : span["end"]] + "</mark>"
)
if idx < len(ss) - 1:
text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]]
text += id2docs[id].text[ss[-1]["end"] :]
text_out = markdown.markdown(
text, extensions=["markdown.extensions.tables"]
)
yield Document(
content=(
"<details open>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text_out}"
"</details><br>"
),
channel="info",
)
lack_evidence = False
if lack_evidence:
yield Document(channel="info", content="No evidence found.\n")
if not_detected:
yield Document(
channel="info",
content="Retrieved segments without matching evidence:\n",
)
for id in list(not_detected):
text_out = markdown.markdown(
id2docs[id].text, extensions=["markdown.extensions.tables"]
)
yield Document(
content=(
"<details>"
f"<summary>{id2docs[id].metadata['file_name']}</summary>"
f"{text_out}"
"</details><br>"
),
channel="info",
)
return answer
@classmethod
def get_pipeline(cls, settings, states, retrievers):
"""Get the reasoning pipeline
@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning):
_id = cls.get_info()["id"]
pipeline = FullQAPipeline(retrievers=retrievers)
pipeline.answering_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.main_llm"]
]
pipeline.answering_pipeline.citation_pipeline.llm = llms[
settings[f"reasoning.options.{_id}.citation_llm"]
]
pipeline.answering_pipeline.llm = llms.get_default()
pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default()
pipeline.answering_pipeline.enable_citation = settings[
f"reasoning.options.{_id}.highlight_citation"
]
@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning):
f"reasoning.options.{_id}.qa_prompt"
]
pipeline.use_rewrite = states.get("app", {}).get("regen", False)
pipeline.rewrite_pipeline.llm = llms.get_lowest_cost()
pipeline.rewrite_pipeline.llm = llms.get_default()
pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get(
settings["reasoning.lang"], "English"
)
@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning):
@classmethod
def get_user_settings(cls) -> dict:
from ktem.components import llms
try:
citation_llm = llms.get_lowest_cost_name()
citation_llm_choices = list(llms.options().keys())
main_llm = llms.get_highest_accuracy_name()
main_llm_choices = list(llms.options().keys())
except Exception as e:
logger.error(e)
citation_llm = None
citation_llm_choices = []
main_llm = None
main_llm_choices = []
return {
"highlight_citation": {
"name": "Highlight Citation",
"value": False,
"component": "checkbox",
},
"citation_llm": {
"name": "LLM for citation",
"value": citation_llm,
"component": "dropdown",
"choices": citation_llm_choices,
},
"main_llm": {
"name": "LLM for main generation",
"value": main_llm,
"component": "dropdown",
"choices": main_llm_choices,
},
"system_prompt": {
"name": "System Prompt",
"value": "This is a question answering system",

View File

@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline
from openai.resources.embeddings import Embeddings
from openai.types.chat.chat_completion import ChatCompletion
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f:
openai_embedding = json.load(f)
@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path):
assert len(results) == 1
# create llm
llm = AzureChatOpenAI(
llm = LCAzureChatOpenAI(
openai_api_base="https://test.openai.azure.com/",
openai_api_key="some-key",
openai_api_version="2023-03-15-preview",

View File

@ -2,4 +2,4 @@ from ktem.main import App
app = App()
demo = app.make()
demo.queue().launch(favicon_path=app._favicon, inbrowser=True)
demo.queue().launch(favicon_path=app._favicon)

View File

@ -5,7 +5,7 @@ from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, la
from kotaemon.contribs.promptui.logs import ResultLog
from kotaemon.embeddings import AzureOpenAIEmbeddings
from kotaemon.indices import VectorIndexing, VectorRetrieval
from kotaemon.llms import AzureChatOpenAI
from kotaemon.llms import LCAzureChatOpenAI
from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore
@ -34,7 +34,7 @@ class QuestionAnsweringPipeline(BaseComponent):
]
retrieval_top_k: int = 1
llm: AzureChatOpenAI = AzureChatOpenAI.withx(
llm: LCAzureChatOpenAI = LCAzureChatOpenAI.withx(
azure_endpoint="https://bleh-dummy-2.openai.azure.com/",
openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"),
openai_api_version="2023-03-15-preview",