diff --git a/.gitignore b/.gitignore index 5c91c3e..711a741 100644 --- a/.gitignore +++ b/.gitignore @@ -458,6 +458,7 @@ logs/ .gitsecret/keys/random_seed !*.secret .envrc +.env S.gpg-agent* .vscode/settings.json diff --git a/docs/development/create-a-component.md b/docs/development/create-a-component.md index 029bbe9..f831259 100644 --- a/docs/development/create-a-component.md +++ b/docs/development/create-a-component.md @@ -22,7 +22,7 @@ The syntax of a component is as follow: ```python from kotaemon.base import BaseComponent -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI from kotaemon.parsers import RegexExtractor @@ -32,7 +32,7 @@ class FancyPipeline(BaseComponent): param3: float node1: BaseComponent # this is a node because of BaseComponent type annotation - node2: AzureChatOpenAI # this is also a node because AzureChatOpenAI subclasses BaseComponent + node2: LCAzureChatOpenAI # this is also a node because LCAzureChatOpenAI subclasses BaseComponent node3: RegexExtractor # this is also a node bceause RegexExtractor subclasses BaseComponent def run(self, some_text: str): @@ -45,7 +45,7 @@ class FancyPipeline(BaseComponent): Then this component can be used as follow: ```python -llm = AzureChatOpenAI(endpoint="some-endpont") +llm = LCAzureChatOpenAI(endpoint="some-endpont") extractor = RegexExtractor(pattern=["yes", "Yes"]) component = FancyPipeline( diff --git a/docs/pages/app/customize-flows.md b/docs/pages/app/customize-flows.md index 1277e34..3dd005e 100644 --- a/docs/pages/app/customize-flows.md +++ b/docs/pages/app/customize-flows.md @@ -193,7 +193,8 @@ information panel. You can access users' collections of LLMs and embedding models with: ```python -from ktem.components import llms, embeddings +from ktem.components import embeddings +from ktem.llms.manager import llms llm = llms.get_default() @@ -206,12 +207,12 @@ models they want to use through the settings. ```python @classmethod def get_user_settings(cls) -> dict: - from ktem.components import llms + from ktem.llms.manager import llms return { "citation_llm": { "name": "LLM for citation", - "value": llms.get_lowest_cost_name(), + "value": llms.get_default(), "component: "dropdown", "choices": list(llms.options().keys()), }, diff --git a/libs/kotaemon/kotaemon/base/component.py b/libs/kotaemon/kotaemon/base/component.py index 9acd39f..6936b2a 100644 --- a/libs/kotaemon/kotaemon/base/component.py +++ b/libs/kotaemon/kotaemon/base/component.py @@ -52,7 +52,7 @@ class BaseComponent(Function): def stream(self, *args, **kwargs) -> Iterator[Document] | None: ... - async def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None: + def astream(self, *args, **kwargs) -> AsyncGenerator[Document, None] | None: ... 
@abstractmethod diff --git a/libs/kotaemon/kotaemon/base/schema.py b/libs/kotaemon/kotaemon/base/schema.py index 1d0e622..07fe9f5 100644 --- a/libs/kotaemon/kotaemon/base/schema.py +++ b/libs/kotaemon/kotaemon/base/schema.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Optional, TypeVar +from typing import TYPE_CHECKING, Any, Literal, Optional, TypeVar from langchain.schema.messages import AIMessage as LCAIMessage from langchain.schema.messages import HumanMessage as LCHumanMessage @@ -10,6 +10,9 @@ from llama_index.schema import Document as BaseDocument if TYPE_CHECKING: from haystack.schema import Document as HaystackDocument + from openai.types.chat.chat_completion_message_param import ( + ChatCompletionMessageParam, + ) IO_Type = TypeVar("IO_Type", "Document", str) SAMPLE_TEXT = "A sample Document from kotaemon" @@ -26,10 +29,15 @@ class Document(BaseDocument): Attributes: content: raw content of the document, can be anything source: id of the source of the Document. Optional. + channel: the channel to show the document. Optional.: + - chat: show in chat message + - info: show in information panel + - debug: show in debug panel """ - content: Any + content: Any = None source: Optional[str] = None + channel: Optional[Literal["chat", "info", "debug"]] = None def __init__(self, content: Optional[Any] = None, *args, **kwargs): if content is None: @@ -87,17 +95,23 @@ class BaseMessage(Document): def __add__(self, other: Any): raise NotImplementedError + def to_openai_format(self) -> "ChatCompletionMessageParam": + raise NotImplementedError + class SystemMessage(BaseMessage, LCSystemMessage): - pass + def to_openai_format(self) -> "ChatCompletionMessageParam": + return {"role": "system", "content": self.content} class AIMessage(BaseMessage, LCAIMessage): - pass + def to_openai_format(self) -> "ChatCompletionMessageParam": + return {"role": "assistant", "content": self.content} class HumanMessage(BaseMessage, LCHumanMessage): - pass + def to_openai_format(self) -> "ChatCompletionMessageParam": + return {"role": "user", "content": self.content} class RetrievedDocument(Document): diff --git a/libs/kotaemon/kotaemon/indices/qa/text_based.py b/libs/kotaemon/kotaemon/indices/qa/text_based.py index 5b1f6e3..e0b49be 100644 --- a/libs/kotaemon/kotaemon/indices/qa/text_based.py +++ b/libs/kotaemon/kotaemon/indices/qa/text_based.py @@ -1,7 +1,7 @@ import os from kotaemon.base import BaseComponent, Document, Node, RetrievedDocument -from kotaemon.llms import AzureChatOpenAI, BaseLLM, PromptTemplate +from kotaemon.llms import BaseLLM, LCAzureChatOpenAI, PromptTemplate from .citation import CitationPipeline @@ -13,7 +13,7 @@ class CitationQAPipeline(BaseComponent): 'Answer the following question: "{question}". 
' "The context is: \n{context}\nAnswer: " ) - llm: BaseLLM = AzureChatOpenAI.withx( + llm: BaseLLM = LCAzureChatOpenAI.withx( azure_endpoint="https://bleh-dummy.openai.azure.com/", openai_api_key=os.environ.get("OPENAI_API_KEY", ""), openai_api_version="2023-07-01-preview", diff --git a/libs/kotaemon/kotaemon/llms/__init__.py b/libs/kotaemon/kotaemon/llms/__init__.py index d7547a6..266e391 100644 --- a/libs/kotaemon/kotaemon/llms/__init__.py +++ b/libs/kotaemon/kotaemon/llms/__init__.py @@ -2,7 +2,15 @@ from kotaemon.base.schema import AIMessage, BaseMessage, HumanMessage, SystemMes from .base import BaseLLM from .branching import GatedBranchingPipeline, SimpleBranchingPipeline -from .chats import AzureChatOpenAI, ChatLLM, ChatOpenAI, EndpointChatLLM, LlamaCppChat +from .chats import ( + AzureChatOpenAI, + ChatLLM, + ChatOpenAI, + EndpointChatLLM, + LCAzureChatOpenAI, + LCChatOpenAI, + LlamaCppChat, +) from .completions import LLM, AzureOpenAI, LlamaCpp, OpenAI from .cot import ManualSequentialChainOfThought, Thought from .linear import GatedLinearPipeline, SimpleLinearPipeline @@ -17,8 +25,10 @@ __all__ = [ "HumanMessage", "AIMessage", "SystemMessage", - "ChatOpenAI", "AzureChatOpenAI", + "ChatOpenAI", + "LCAzureChatOpenAI", + "LCChatOpenAI", "LlamaCppChat", # completion-specific components "LLM", diff --git a/libs/kotaemon/kotaemon/llms/base.py b/libs/kotaemon/kotaemon/llms/base.py index 6ef7afc..374d139 100644 --- a/libs/kotaemon/kotaemon/llms/base.py +++ b/libs/kotaemon/kotaemon/llms/base.py @@ -18,5 +18,8 @@ class BaseLLM(BaseComponent): def stream(self, *args, **kwargs) -> Iterator[LLMInterface]: raise NotImplementedError - async def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]: + def astream(self, *args, **kwargs) -> AsyncGenerator[LLMInterface, None]: raise NotImplementedError + + def run(self, *args, **kwargs): + return self.invoke(*args, **kwargs) diff --git a/libs/kotaemon/kotaemon/llms/branching.py b/libs/kotaemon/kotaemon/llms/branching.py index a9cbbe8..ee49dc5 100644 --- a/libs/kotaemon/kotaemon/llms/branching.py +++ b/libs/kotaemon/kotaemon/llms/branching.py @@ -15,7 +15,7 @@ class SimpleBranchingPipeline(BaseComponent): Example: ```python from kotaemon.llms import ( - AzureChatOpenAI, + LCAzureChatOpenAI, BasePromptComponent, GatedLinearPipeline, ) @@ -25,7 +25,7 @@ class SimpleBranchingPipeline(BaseComponent): return x pipeline = SimpleBranchingPipeline() - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", @@ -92,7 +92,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline): Example: ```python from kotaemon.llms import ( - AzureChatOpenAI, + LCAzureChatOpenAI, BasePromptComponent, GatedLinearPipeline, ) @@ -102,7 +102,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline): return x pipeline = GatedBranchingPipeline() - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", @@ -157,7 +157,7 @@ class GatedBranchingPipeline(SimpleBranchingPipeline): if __name__ == "__main__": import dotenv - from kotaemon.llms import AzureChatOpenAI, BasePromptComponent + from kotaemon.llms import BasePromptComponent, LCAzureChatOpenAI from kotaemon.parsers import RegexExtractor def identity(x): @@ -166,7 +166,7 @@ if __name__ == "__main__": secrets = dotenv.dotenv_values(".env") pipeline = GatedBranchingPipeline() - llm = 
AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base=secrets.get("OPENAI_API_BASE", ""), openai_api_key=secrets.get("OPENAI_API_KEY", ""), openai_api_version=secrets.get("OPENAI_API_VERSION", ""), diff --git a/libs/kotaemon/kotaemon/llms/chats/__init__.py b/libs/kotaemon/kotaemon/llms/chats/__init__.py index 5b50317..7fc1c40 100644 --- a/libs/kotaemon/kotaemon/llms/chats/__init__.py +++ b/libs/kotaemon/kotaemon/llms/chats/__init__.py @@ -1,13 +1,17 @@ from .base import ChatLLM from .endpoint_based import EndpointChatLLM -from .langchain_based import AzureChatOpenAI, ChatOpenAI, LCChatMixin +from .langchain_based import LCAzureChatOpenAI, LCChatMixin, LCChatOpenAI from .llamacpp import LlamaCppChat +from .openai import AzureChatOpenAI, ChatOpenAI __all__ = [ + "ChatOpenAI", + "AzureChatOpenAI", "ChatLLM", "EndpointChatLLM", "ChatOpenAI", - "AzureChatOpenAI", + "LCChatOpenAI", + "LCAzureChatOpenAI", "LCChatMixin", "LlamaCppChat", ] diff --git a/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py b/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py index 170ec8b..5ab1835 100644 --- a/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/endpoint_based.py @@ -5,6 +5,7 @@ from kotaemon.base import ( BaseMessage, HumanMessage, LLMInterface, + Param, SystemMessage, ) @@ -20,7 +21,9 @@ class EndpointChatLLM(ChatLLM): endpoint_url (str): The url of a OpenAI API compatible endpoint. """ - endpoint_url: str + endpoint_url: str = Param( + help="URL of the OpenAI API compatible endpoint", required=True + ) def run( self, messages: str | BaseMessage | list[BaseMessage], **kwargs diff --git a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py index 526eaf8..fca78dc 100644 --- a/libs/kotaemon/kotaemon/llms/chats/langchain_based.py +++ b/libs/kotaemon/kotaemon/llms/chats/langchain_based.py @@ -165,7 +165,7 @@ class LCChatMixin: raise ValueError(f"Invalid param {path}") -class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore +class LCChatOpenAI(LCChatMixin, ChatLLM): # type: ignore def __init__( self, openai_api_base: str | None = None, @@ -193,7 +193,7 @@ class ChatOpenAI(LCChatMixin, ChatLLM): # type: ignore return ChatOpenAI -class AzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore +class LCAzureChatOpenAI(LCChatMixin, ChatLLM): # type: ignore def __init__( self, azure_endpoint: str | None = None, diff --git a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py index 62ee0ea..7b8bee4 100644 --- a/libs/kotaemon/kotaemon/llms/chats/llamacpp.py +++ b/libs/kotaemon/kotaemon/llms/chats/llamacpp.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Iterator, Optional, cast from kotaemon.base import BaseMessage, HumanMessage, LLMInterface, Param @@ -12,13 +12,32 @@ if TYPE_CHECKING: class LlamaCppChat(ChatLLM): """Wrapper around the llama-cpp-python's Llama model""" - model_path: Optional[str] = None - chat_format: Optional[str] = None - lora_base: Optional[str] = None - n_ctx: int = 512 - n_gpu_layers: int = 0 - use_mmap: bool = True - vocab_only: bool = False + model_path: str = Param( + help="Path to the model file. This is required to load the model.", + required=True, + ) + chat_format: str = Param( + help=( + "Chat format to use. Please refer to llama_cpp.llama_chat_format for a " + "list of supported formats. If blank, the chat format will be auto-" + "inferred." 
+ ), + required=True, + ) + lora_base: Optional[str] = Param(None, help="Path to the base Lora model") + n_ctx: Optional[int] = Param(512, help="Text context, 0 = from model") + n_gpu_layers: Optional[int] = Param( + 0, + help=("Number of layers to offload to GPU. If -1, all layers are offloaded"), + ) + use_mmap: Optional[bool] = Param( + True, + help=("Whether to use memory mapping (mmap) to load the model"), + ) + vocab_only: Optional[bool] = Param( + False, + help=("If True, only the vocabulary is loaded. This is useful for debugging."), + ) _role_mapper: dict[str, str] = { "human": "user", @@ -60,9 +79,9 @@ class LlamaCppChat(ChatLLM): vocab_only=self.vocab_only, ) - def run( - self, messages: str | BaseMessage | list[BaseMessage], **kwargs - ) -> LLMInterface: + def prepare_message( + self, messages: str | BaseMessage | list[BaseMessage] + ) -> list[dict]: input_: list[BaseMessage] = [] if isinstance(messages, str): @@ -72,11 +91,19 @@ else: input_ = messages + output_ = [ + {"role": self._role_mapper[each.type], "content": each.content} + for each in input_ + ] + + return output_ + + def invoke( + self, messages: str | BaseMessage | list[BaseMessage], **kwargs + ) -> LLMInterface: + pred: "CCCR" = self.client_object.create_chat_completion( - messages=[ - {"role": self._role_mapper[each.type], "content": each.content} - for each in input_ - ], # type: ignore + messages=self.prepare_message(messages), stream=False, ) @@ -91,3 +118,19 @@ total_tokens=pred["usage"]["total_tokens"], prompt_tokens=pred["usage"]["prompt_tokens"], ) + + def stream( + self, messages: str | BaseMessage | list[BaseMessage], **kwargs + ) -> Iterator[LLMInterface]: + pred = self.client_object.create_chat_completion( + messages=self.prepare_message(messages), + stream=True, + ) + for chunk in pred: + if not chunk["choices"]: + continue + + if "content" not in chunk["choices"][0]["delta"]: + continue + + yield LLMInterface(content=chunk["choices"][0]["delta"]["content"]) diff --git a/libs/kotaemon/kotaemon/llms/chats/openai.py b/libs/kotaemon/kotaemon/llms/chats/openai.py new file mode 100644 index 0000000..6f492c7 --- /dev/null +++ b/libs/kotaemon/kotaemon/llms/chats/openai.py @@ -0,0 +1,356 @@ +from typing import TYPE_CHECKING, AsyncGenerator, Iterator, Optional + +from theflow.utils.modules import import_dotted_string + +from kotaemon.base import AIMessage, BaseMessage, HumanMessage, LLMInterface, Param + +from .base import ChatLLM + +if TYPE_CHECKING: + from openai.types.chat.chat_completion_message_param import ( + ChatCompletionMessageParam, + ) + + +class BaseChatOpenAI(ChatLLM): + """Base interface for OpenAI chat model, using the openai library + + This class exposes the parameters in resources.Chat. To subclass this class: + + - Implement the `prepare_client` method to return the OpenAI client + - Implement the `openai_response` method to return the OpenAI response + - Implement the params related to the OpenAI client + """ + + _dependencies = ["openai"] + _capabilities = ["chat", "text"] # consider as mixin + + api_key: str = Param(help="API key", required=True) + timeout: Optional[float] = Param(None, help="Timeout for the API request") + max_retries: Optional[int] = Param( + None, help="Maximum number of retries for the API request" + ) + + temperature: Optional[float] = Param( + None, + help=( + "Number between 0 and 2 that controls the randomness of the generated " + "tokens. Lower values make the model more deterministic, while higher " + "values make the model more random."
+ ), + ) + max_tokens: Optional[int] = Param( + None, + help=( + "Maximum number of tokens to generate. The total length of input tokens " + "and generated tokens is limited by the model's context length." + ), + ) + n: int = Param( + 1, + help=( + "Number of completions to generate. The API will generate n completions " + "for each prompt." + ), + ) + stop: Optional[str | list[str]] = Param( + None, + help=( + "Stop sequence. If a stop sequence is detected, generation will stop " + "at that point. If not specified, generation will continue until the " + "maximum token length is reached." + ), + ) + frequency_penalty: Optional[float] = Param( + None, + help=( + "Number between -2.0 and 2.0. Positive values penalize new tokens " + "based on their existing frequency in the text so far, decreasing the " + "model's likelihood of repeating the same text." + ), + ) + presence_penalty: Optional[float] = Param( + None, + help=( + "Number between -2.0 and 2.0. Positive values penalize new tokens " + "based on their existing presence in the text so far, decreasing the " + "model's likelihood of repeating the same text." + ), + ) + tool_choice: Optional[str] = Param( + None, + help=( + "Choice of tool to use for the completion. Available choices are: " + "auto, default." + ), + ) + tools: Optional[list[str]] = Param( + None, + help="List of tools to use for the completion.", + ) + logprobs: Optional[bool] = Param( + None, + help=( + "Include log probabilities on the logprobs most likely tokens, " + "as well as the chosen token." + ), + ) + logit_bias: Optional[dict] = Param( + None, + help=( + "Dictionary of logit bias values to add to the logits of the tokens " + "in the vocabulary." + ), + ) + top_logprobs: Optional[int] = Param( + None, + help=( + "An integer between 0 and 5 specifying the number of most likely tokens " + "to return at each token position, each with an associated log " + "probability. `logprobs` must also be set to `true` if this parameter " + "is used." + ), + ) + top_p: Optional[float] = Param( + None, + help=( + "An alternative to sampling with temperature, called nucleus sampling, " + "where the model considers the results of the token with top_p " + "probability mass. So 0.1 means that only the tokens comprising the " + "top 10% probability mass are considered."
+ ), + ) + + @Param.auto(depends_on=["max_retries"]) + def max_retries_(self): + if self.max_retries is None: + from openai._constants import DEFAULT_MAX_RETRIES + + return DEFAULT_MAX_RETRIES + return self.max_retries + + def prepare_message( + self, messages: str | BaseMessage | list[BaseMessage] + ) -> list["ChatCompletionMessageParam"]: + """Prepare the message into OpenAI format + + Returns: + list[dict]: List of messages in OpenAI format + """ + input_: list[BaseMessage] = [] + output_: list["ChatCompletionMessageParam"] = [] + + if isinstance(messages, str): + input_ = [HumanMessage(content=messages)] + elif isinstance(messages, BaseMessage): + input_ = [messages] + else: + input_ = messages + + for message in input_: + output_.append(message.to_openai_format()) + + return output_ + + def prepare_client(self, async_version: bool = False): + """Get the OpenAI client + + Args: + async_version (bool): Whether to get the async version of the client + """ + raise NotImplementedError + + def openai_response(self, client, **kwargs): + """Get the openai response""" + raise NotImplementedError + + def invoke( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> LLMInterface: + client = self.prepare_client(async_version=False) + input_messages = self.prepare_message(messages) + resp = self.openai_response( + client, messages=input_messages, stream=False, **kwargs + ).dict() + + output = LLMInterface( + candidates=[_["message"]["content"] for _ in resp["choices"]], + content=resp["choices"][0]["message"]["content"], + total_tokens=resp["usage"]["total_tokens"], + prompt_tokens=resp["usage"]["prompt_tokens"], + completion_tokens=resp["usage"]["completion_tokens"], + messages=[ + AIMessage(content=_["message"]["content"]) for _ in resp["choices"] + ], + ) + + return output + + async def ainvoke( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> LLMInterface: + client = self.prepare_client(async_version=True) + input_messages = self.prepare_message(messages) + resp = await self.openai_response( + client, messages=input_messages, stream=False, **kwargs + ).dict() + + output = LLMInterface( + candidates=[_["message"]["content"] for _ in resp["choices"]], + content=resp["choices"][0]["message"]["content"], + total_tokens=resp["usage"]["total_tokens"], + prompt_tokens=resp["usage"]["prompt_tokens"], + completion_tokens=resp["usage"]["completion_tokens"], + messages=[ + AIMessage(content=_["message"]["content"]) for _ in resp["choices"] + ], + ) + + return output + + def stream( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> Iterator[LLMInterface]: + client = self.prepare_client(async_version=False) + input_messages = self.prepare_message(messages) + resp = self.openai_response( + client, messages=input_messages, stream=True, **kwargs + ) + + for chunk in resp: + if not chunk.choices: + continue + if chunk.choices[0].delta.content is not None: + yield LLMInterface(content=chunk.choices[0].delta.content) + + async def astream( + self, messages: str | BaseMessage | list[BaseMessage], *args, **kwargs + ) -> AsyncGenerator[LLMInterface, None]: + client = self.prepare_client(async_version=True) + input_messages = self.prepare_message(messages) + resp = self.openai_response( + client, messages=input_messages, stream=True, **kwargs + ) + + async for chunk in resp: + if not chunk.choices: + continue + if chunk.choices[0].delta.content is not None: + yield LLMInterface(content=chunk.choices[0].delta.content) + + +class 
ChatOpenAI(BaseChatOpenAI): + """OpenAI chat model""" + + base_url: Optional[str] = Param(None, help="OpenAI base URL") + organization: Optional[str] = Param(None, help="OpenAI organization") + model: str = Param(help="OpenAI model", required=True) + + def prepare_client(self, async_version: bool = False): + """Get the OpenAI client + + Args: + async_version (bool): Whether to get the async version of the client + """ + params = { + "api_key": self.api_key, + "organization": self.organization, + "base_url": self.base_url, + "timeout": self.timeout, + "max_retries": self.max_retries_, + } + if async_version: + from openai import AsyncOpenAI + + return AsyncOpenAI(**params) + + from openai import OpenAI + + return OpenAI(**params) + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = { + "model": self.model, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "n": self.n, + "stop": self.stop, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "tool_choice": self.tool_choice, + "tools": self.tools, + "logprobs": self.logprobs, + "logit_bias": self.logit_bias, + "top_logprobs": self.top_logprobs, + "top_p": self.top_p, + } + params.update(kwargs) + + return client.chat.completions.create(**params) + + +class AzureChatOpenAI(BaseChatOpenAI): + """OpenAI chat model provided by Microsoft Azure""" + + azure_endpoint: str = Param( + help=( + "HTTPS endpoint for the Azure OpenAI model. The azure_endpoint, " + "azure_deployment, and api_version parameters are used to construct " + "the full URL for the Azure OpenAI model." + ) + ) + azure_deployment: str = Param(help="Azure deployment name", required=True) + api_version: str = Param(help="Azure model version", required=True) + azure_ad_token: Optional[str] = Param(None, help="Azure AD token") + azure_ad_token_provider: Optional[str] = Param(None, help="Azure AD token provider") + + @Param.auto(depends_on=["azure_ad_token_provider"]) + def azure_ad_token_provider_(self): + if isinstance(self.azure_ad_token_provider, str): + return import_dotted_string(self.azure_ad_token_provider, safe=False) + + def prepare_client(self, async_version: bool = False): + """Get the OpenAI client + + Args: + async_version (bool): Whether to get the async version of the client + """ + params = { + "azure_endpoint": self.azure_endpoint, + "api_version": self.api_version, + "api_key": self.api_key, + "azure_ad_token": self.azure_ad_token, + "azure_ad_token_provider": self.azure_ad_token_provider_, + "timeout": self.timeout, + "max_retries": self.max_retries_, + } + if async_version: + from openai import AsyncAzureOpenAI + + return AsyncAzureOpenAI(**params) + + from openai import AzureOpenAI + + return AzureOpenAI(**params) + + def openai_response(self, client, **kwargs): + """Get the openai response""" + params = { + "model": self.azure_deployment, + "temperature": self.temperature, + "max_tokens": self.max_tokens, + "n": self.n, + "stop": self.stop, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "tool_choice": self.tool_choice, + "tools": self.tools, + "logprobs": self.logprobs, + "logit_bias": self.logit_bias, + "top_logprobs": self.top_logprobs, + "top_p": self.top_p, + } + params.update(kwargs) + + return client.chat.completions.create(**params) diff --git a/libs/kotaemon/kotaemon/llms/cot.py b/libs/kotaemon/kotaemon/llms/cot.py index 7eaf5d1..a52f9bd 100644 --- a/libs/kotaemon/kotaemon/llms/cot.py +++ 
b/libs/kotaemon/kotaemon/llms/cot.py @@ -5,7 +5,7 @@ from theflow import Function, Node, Param from kotaemon.base import BaseComponent, Document -from .chats import AzureChatOpenAI +from .chats import LCAzureChatOpenAI from .completions import LLM from .prompts import BasePromptComponent @@ -25,7 +25,7 @@ class Thought(BaseComponent): >> from kotaemon.pipelines.cot import Thought >> thought = Thought( prompt="How to {action} {object}?", - llm=AzureChatOpenAI(...), + llm=LCAzureChatOpenAI(...), post_process=lambda string: {"tutorial": string}, ) >> output = thought(action="install", object="python") @@ -42,7 +42,7 @@ class Thought(BaseComponent): This `Thought` allows chaining sequentially with the + operator. For example: ```python - >> llm = AzureChatOpenAI(...) + >> llm = LCAzureChatOpenAI(...) >> thought1 = Thought( prompt="Word {word} in {language} is ", llm=llm, @@ -73,7 +73,7 @@ class Thought(BaseComponent): " component is executed" ) ) - llm: LLM = Node(AzureChatOpenAI, help="The LLM model to execute the input prompt") + llm: LLM = Node(LCAzureChatOpenAI, help="The LLM model to execute the input prompt") post_process: Function = Node( help=( "The function post-processor that post-processes LLM output prediction ." @@ -117,7 +117,7 @@ class ManualSequentialChainOfThought(BaseComponent): ```pycon >>> from kotaemon.pipelines.cot import Thought, ManualSequentialChainOfThought - >>> llm = AzureChatOpenAI(...) + >>> llm = LCAzureChatOpenAI(...) >>> thought1 = Thought( >>> prompt="Word {word} in {language} is ", >>> post_process=lambda string: {"translated": string}, diff --git a/libs/kotaemon/kotaemon/llms/linear.py b/libs/kotaemon/kotaemon/llms/linear.py index ac8605a..4c61597 100644 --- a/libs/kotaemon/kotaemon/llms/linear.py +++ b/libs/kotaemon/kotaemon/llms/linear.py @@ -22,12 +22,12 @@ class SimpleLinearPipeline(BaseComponent): Example Usage: ```python - from kotaemon.llms import AzureChatOpenAI, BasePromptComponent + from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent def identity(x): return x - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", @@ -89,13 +89,13 @@ class GatedLinearPipeline(SimpleLinearPipeline): Usage: ```{.py3 title="Example Usage"} - from kotaemon.llms import AzureChatOpenAI, BasePromptComponent + from kotaemon.llms import LCAzureChatOpenAI, BasePromptComponent from kotaemon.parsers import RegexExtractor def identity(x): return x - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="your openai api base", openai_api_key="your openai api key", openai_api_version="your openai api version", diff --git a/libs/kotaemon/pyproject.toml b/libs/kotaemon/pyproject.toml index 73c3e8a..a02337e 100644 --- a/libs/kotaemon/pyproject.toml +++ b/libs/kotaemon/pyproject.toml @@ -11,7 +11,7 @@ packages.find.exclude = ["tests*", "env*"] # metadata and dependencies [project] name = "kotaemon" -version = "0.3.8" +version = "0.3.9" requires-python = ">= 3.10" description = "Kotaemon core library for AI development." 
dependencies = [ diff --git a/libs/kotaemon/tests/test_agent.py b/libs/kotaemon/tests/test_agent.py index 0cc65fa..d489af9 100644 --- a/libs/kotaemon/tests/test_agent.py +++ b/libs/kotaemon/tests/test_agent.py @@ -13,7 +13,7 @@ from kotaemon.agents import ( RewooAgent, WikipediaTool, ) -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI FINAL_RESPONSE_TEXT = "Final Answer: Hello Cinnamon AI!" REWOO_VALID_PLAN = ( @@ -112,7 +112,7 @@ _openai_chat_completion_responses_react_langchain_tool = [ @pytest.fixture def llm(): - return AzureChatOpenAI( + return LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", diff --git a/libs/kotaemon/tests/test_composite.py b/libs/kotaemon/tests/test_composite.py index 464a456..38e79bd 100644 --- a/libs/kotaemon/tests/test_composite.py +++ b/libs/kotaemon/tests/test_composite.py @@ -4,10 +4,10 @@ import pytest from openai.types.chat.chat_completion import ChatCompletion from kotaemon.llms import ( - AzureChatOpenAI, BasePromptComponent, GatedBranchingPipeline, GatedLinearPipeline, + LCAzureChatOpenAI, SimpleBranchingPipeline, SimpleLinearPipeline, ) @@ -40,7 +40,7 @@ _openai_chat_completion_response = ChatCompletion.parse_obj( @pytest.fixture def mock_llm(): - return AzureChatOpenAI( + return LCAzureChatOpenAI( azure_endpoint="OPENAI_API_BASE", openai_api_key="OPENAI_API_KEY", openai_api_version="OPENAI_API_VERSION", diff --git a/libs/kotaemon/tests/test_cot.py b/libs/kotaemon/tests/test_cot.py index aef8a69..5fd1344 100644 --- a/libs/kotaemon/tests/test_cot.py +++ b/libs/kotaemon/tests/test_cot.py @@ -2,7 +2,7 @@ from unittest.mock import patch from openai.types.chat.chat_completion import ChatCompletion -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI from kotaemon.llms.cot import ManualSequentialChainOfThought, Thought _openai_chat_completion_response = [ @@ -38,7 +38,7 @@ _openai_chat_completion_response = [ side_effect=_openai_chat_completion_response, ) def test_cot_plus_operator(openai_completion): - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", @@ -70,7 +70,7 @@ def test_cot_plus_operator(openai_completion): side_effect=_openai_chat_completion_response, ) def test_cot_manual(openai_completion): - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", @@ -100,7 +100,7 @@ def test_cot_manual(openai_completion): side_effect=_openai_chat_completion_response, ) def test_cot_with_termination_callback(openai_completion): - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", diff --git a/libs/kotaemon/tests/test_llms_chat_models.py b/libs/kotaemon/tests/test_llms_chat_models.py index a6a2a24..3758f76 100644 --- a/libs/kotaemon/tests/test_llms_chat_models.py +++ b/libs/kotaemon/tests/test_llms_chat_models.py @@ -4,7 +4,7 @@ from unittest.mock import patch import pytest from kotaemon.base.schema import AIMessage, HumanMessage, LLMInterface, SystemMessage -from kotaemon.llms import AzureChatOpenAI, LlamaCppChat +from kotaemon.llms import LCAzureChatOpenAI, LlamaCppChat try: from langchain_openai import AzureChatOpenAI as AzureChatOpenAILC @@ -43,7 +43,7 @@ 
_openai_chat_completion_response = ChatCompletion.parse_obj( side_effect=lambda *args, **kwargs: _openai_chat_completion_response, ) def test_azureopenai_model(openai_completion): - model = AzureChatOpenAI( + model = LCAzureChatOpenAI( azure_endpoint="https://test.openai.azure.com/", openai_api_key="some-key", openai_api_version="2023-03-15-preview", diff --git a/libs/kotaemon/tests/test_reranking.py b/libs/kotaemon/tests/test_reranking.py index d4f7be8..ee37d3c 100644 --- a/libs/kotaemon/tests/test_reranking.py +++ b/libs/kotaemon/tests/test_reranking.py @@ -5,7 +5,7 @@ from openai.types.chat.chat_completion import ChatCompletion from kotaemon.base import Document from kotaemon.indices.rankings import LLMReranking -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI _openai_chat_completion_responses = [ ChatCompletion.parse_obj( @@ -41,7 +41,7 @@ _openai_chat_completion_responses = [ @pytest.fixture def llm(): - return AzureChatOpenAI( + return LCAzureChatOpenAI( azure_endpoint="https://dummy.openai.azure.com/", openai_api_key="dummy", openai_api_version="2023-03-15-preview", diff --git a/libs/ktem/flowsettings.py b/libs/ktem/flowsettings.py index 33ba88f..2e26cff 100644 --- a/libs/ktem/flowsettings.py +++ b/libs/ktem/flowsettings.py @@ -40,16 +40,15 @@ if config("AZURE_OPENAI_API_KEY", default="") and config( ): if config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""): KH_LLMS["azure"] = { - "def": { + "spec": { "__type__": "kotaemon.llms.AzureChatOpenAI", "temperature": 0, "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""), - "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""), + "api_key": config("AZURE_OPENAI_API_KEY", default=""), "api_version": config("OPENAI_API_VERSION", default="") or "2024-02-15-preview", - "deployment_name": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""), - "request_timeout": 10, - "stream": False, + "azure_deployment": config("AZURE_OPENAI_CHAT_DEPLOYMENT", default=""), + "timeout": 20, }, "default": False, "accuracy": 5, @@ -57,7 +56,7 @@ if config("AZURE_OPENAI_API_KEY", default="") and config( } if config("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT", default=""): KH_EMBEDDINGS["azure"] = { - "def": { + "spec": { "__type__": "kotaemon.embeddings.AzureOpenAIEmbeddings", "azure_endpoint": config("AZURE_OPENAI_ENDPOINT", default=""), "openai_api_key": config("AZURE_OPENAI_API_KEY", default=""), @@ -164,5 +163,11 @@ KH_INDICES = [ "name": "File", "config": {}, "index_type": "ktem.index.file.FileIndex", - } + }, + { + "id": 2, + "name": "Sample", + "config": {}, + "index_type": "ktem.index.file.FileIndex", + }, ] diff --git a/libs/ktem/ktem/components.py b/libs/ktem/ktem/components.py index 6cfb2e3..182cb91 100644 --- a/libs/ktem/ktem/components.py +++ b/libs/ktem/ktem/components.py @@ -3,6 +3,7 @@ import logging from functools import cache from pathlib import Path +from typing import Optional from theflow.settings import settings from theflow.utils.modules import deserialize @@ -48,7 +49,7 @@ class ModelPool: self._default: list[str] = [] for name, model in conf.items(): - self._models[name] = deserialize(model["def"], safe=False) + self._models[name] = deserialize(model["spec"], safe=False) if model.get("default", False): self._default.append(name) @@ -58,11 +59,27 @@ class ModelPool: self._cost = list(sorted(conf, key=lambda x: conf[x].get("cost", float("inf")))) def __getitem__(self, key: str) -> BaseComponent: + """Get model by name""" return self._models[key] def __setitem__(self, key: str, value: 
BaseComponent): + """Set model by name""" self._models[key] = value + def __delitem__(self, key: str): + """Delete model by name""" + del self._models[key] + + def __contains__(self, key: str) -> bool: + """Check if model exists""" + return key in self._models + + def get( + self, key: str, default: Optional[BaseComponent] = None + ) -> Optional[BaseComponent]: + """Get model by name with default value""" + return self._models.get(key, default) + def settings(self) -> dict: """Present model pools option for gradio""" return { @@ -169,4 +186,3 @@ llms = ModelPool("LLMs", settings.KH_LLMS) embeddings = ModelPool("Embeddings", settings.KH_EMBEDDINGS) reasonings: dict = {} tools = ModelPool("Tools", {}) -indices = ModelPool("Indices", {}) diff --git a/libs/ktem/ktem/index/file/pipelines.py b/libs/ktem/ktem/index/file/pipelines.py index b63d89c..13036f3 100644 --- a/libs/ktem/ktem/index/file/pipelines.py +++ b/libs/ktem/ktem/index/file/pipelines.py @@ -157,10 +157,10 @@ class DocumentRetrievalPipeline(BaseFileIndexRetriever): @classmethod def get_user_settings(cls) -> dict: - from ktem.components import llms + from ktem.llms.manager import llms try: - reranking_llm = llms.get_lowest_cost_name() + reranking_llm = llms.get_default_name() reranking_llm_choices = list(llms.options().keys()) except Exception as e: logger.error(e) diff --git a/libs/ktem/ktem/llms/__init__.py b/libs/ktem/ktem/llms/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/libs/ktem/ktem/llms/db.py b/libs/ktem/ktem/llms/db.py new file mode 100644 index 0000000..628ebb7 --- /dev/null +++ b/libs/ktem/ktem/llms/db.py @@ -0,0 +1,36 @@ +from typing import Type + +from ktem.db.engine import engine +from sqlalchemy import JSON, Boolean, Column, String +from sqlalchemy.orm import DeclarativeBase +from theflow.settings import settings as flowsettings +from theflow.utils.modules import import_dotted_string + + +class Base(DeclarativeBase): + pass + + +class BaseLLMTable(Base): + """Base table to store language model""" + + __abstract__ = True + + name = Column(String, primary_key=True, unique=True) + spec = Column(JSON, default={}) + default = Column(Boolean, default=False) + + +_base_llm: Type[BaseLLMTable] = ( + import_dotted_string(flowsettings.KH_TABLE_LLM, safe=False) + if hasattr(flowsettings, "KH_TABLE_LLM") + else BaseLLMTable +) + + +class LLMTable(_base_llm): # type: ignore + __tablename__ = "llm_table" + + +if not getattr(flowsettings, "KH_ENABLE_ALEMBIC", False): + LLMTable.metadata.create_all(engine) diff --git a/libs/ktem/ktem/llms/manager.py b/libs/ktem/ktem/llms/manager.py new file mode 100644 index 0000000..f9ad763 --- /dev/null +++ b/libs/ktem/ktem/llms/manager.py @@ -0,0 +1,191 @@ +from typing import Optional, Type + +from sqlalchemy import select +from sqlalchemy.orm import Session +from theflow.settings import settings as flowsettings +from theflow.utils.modules import deserialize + +from kotaemon.base import BaseComponent + +from .db import LLMTable, engine + + +class LLMManager: + """Represent a pool of models""" + + def __init__(self): + self._models: dict[str, BaseComponent] = {} + self._info: dict[str, dict] = {} + self._default: str = "" + self._vendors: list[Type] = [] + + if hasattr(flowsettings, "KH_LLMS"): + for name, model in flowsettings.KH_LLMS.items(): + with Session(engine) as session: + stmt = select(LLMTable).where(LLMTable.name == name) + result = session.execute(stmt) + if not result.first(): + item = LLMTable( + name=name, + spec=model["spec"], + default=model.get("default", 
False), + ) + session.add(item) + session.commit() + + self.load() + self.load_vendors() + + def load(self): + """Load the model pool from database""" + self._models, self._info, self._default = {}, {}, "" + with Session(engine) as session: + stmt = select(LLMTable) + items = session.execute(stmt) + + for (item,) in items: + self._models[item.name] = deserialize(item.spec, safe=False) + self._info[item.name] = { + "name": item.name, + "spec": item.spec, + "default": item.default, + } + if item.default: + self._default = item.name + + def load_vendors(self): + from kotaemon.llms import ( + AzureChatOpenAI, + ChatOpenAI, + EndpointChatLLM, + LlamaCppChat, + ) + + self._vendors = [ChatOpenAI, AzureChatOpenAI, LlamaCppChat, EndpointChatLLM] + + def __getitem__(self, key: str) -> BaseComponent: + """Get model by name""" + return self._models[key] + + def __contains__(self, key: str) -> bool: + """Check if model exists""" + return key in self._models + + def get( + self, key: str, default: Optional[BaseComponent] = None + ) -> Optional[BaseComponent]: + """Get model by name with default value""" + return self._models.get(key, default) + + def settings(self) -> dict: + """Present model pools option for gradio""" + return { + "label": "LLM", + "choices": list(self._models.keys()), + "value": self.get_default_name(), + } + + def options(self) -> dict: + """Present a dict of models""" + return self._models + + def get_random_name(self) -> str: + """Get the name of random model + + Returns: + str: random model name in the pool + """ + import random + + if not self._models: + raise ValueError("No models in pool") + + return random.choice(list(self._models.keys())) + + def get_default_name(self) -> str: + """Get the name of default model + + In case there is no default model, choose random model from pool. In + case there are multiple default models, choose random from them. + + Returns: + str: model name + """ + if not self._models: + raise ValueError("No models in pool") + + if not self._default: + return self.get_random_name() + + return self._default + + def get_random(self) -> BaseComponent: + """Get random model""" + return self._models[self.get_random_name()] + + def get_default(self) -> BaseComponent: + """Get default model + + In case there is no default model, choose random model from pool. In + case there are multiple default models, choose random from them.
+ + Returns: + BaseComponent: model + """ + return self._models[self.get_default_name()] + + def info(self) -> dict: + """List all models""" + return self._info + + def add(self, name: str, spec: dict, default: bool): + """Add a new model to the pool""" + try: + with Session(engine) as session: + item = LLMTable(name=name, spec=spec, default=default) + session.add(item) + session.commit() + except Exception as e: + raise ValueError(f"Failed to add model {name}: {e}") + + self.load() + + def delete(self, name: str): + """Delete a model from the pool""" + try: + with Session(engine) as session: + item = session.query(LLMTable).filter_by(name=name).first() + session.delete(item) + session.commit() + except Exception as e: + raise ValueError(f"Failed to delete model {name}: {e}") + + self.load() + + def update(self, name: str, spec: dict, default: bool): + """Update a model in the pool""" + try: + with Session(engine) as session: + + if default: + # turn all models to non-default + session.query(LLMTable).update({"default": False}) + session.commit() + + item = session.query(LLMTable).filter_by(name=name).first() + if not item: + raise ValueError(f"Model {name} not found") + item.spec = spec + item.default = default + session.commit() + except Exception as e: + raise ValueError(f"Failed to update model {name}: {e}") + + self.load() + + def vendors(self) -> dict: + """Return list of vendors""" + return {vendor.__qualname__: vendor for vendor in self._vendors} + + +llms = LLMManager() diff --git a/libs/ktem/ktem/llms/ui.py b/libs/ktem/ktem/llms/ui.py new file mode 100644 index 0000000..dd8f2bd --- /dev/null +++ b/libs/ktem/ktem/llms/ui.py @@ -0,0 +1,318 @@ +from copy import deepcopy + +import gradio as gr +import pandas as pd +import yaml +from ktem.app import BasePage + +from .manager import llms + + +def format_description(cls): + params = cls.describe()["params"] + params_lines = ["| Name | Type | Description |", "| --- | --- | --- |"] + for key, value in params.items(): + if isinstance(value["auto_callback"], str): + continue + params_lines.append(f"| {key} | {value['type']} | {value['help']} |") + return f"{cls.__doc__}\n\n" + "\n".join(params_lines) + + +class LLMManagement(BasePage): + def __init__(self, app): + self._app = app + self.spec_desc_default = ( + "# Spec description\n\nSelect an LLM to view the spec description." + ) + self.on_building_ui() + + def on_building_ui(self): + with gr.Tab(label="View"): + self.llm_list = gr.DataFrame( + headers=["name", "vendor", "default"], + interactive=False, + ) + + with gr.Column(visible=False) as self._selected_panel: + self.selected_llm_name = gr.Textbox(value="", visible=False) + with gr.Row(): + with gr.Column(): + self.edit_default = gr.Checkbox( + label="Set default", + info=( + "Set this LLM as default. If no default is set, a " + "random LLM will be used." 
+ ), + ) + self.edit_spec = gr.Textbox( + label="Specification", + info="Specification of the LLM in YAML format", + lines=10, + ) + + with gr.Row(visible=False) as self._selected_panel_btn: + with gr.Column(): + self.btn_edit_save = gr.Button("Save", min_width=10) + with gr.Column(): + self.btn_delete = gr.Button("Delete", min_width=10) + with gr.Row(): + self.btn_delete_yes = gr.Button( + "Confirm delete", + variant="primary", + visible=False, + min_width=10, + ) + self.btn_delete_no = gr.Button( + "Cancel", visible=False, min_width=10 + ) + with gr.Column(): + self.btn_close = gr.Button("Close", min_width=10) + + with gr.Column(): + self.edit_spec_desc = gr.Markdown("# Spec description") + + with gr.Tab(label="Add"): + with gr.Row(): + with gr.Column(scale=2): + self.name = gr.Textbox( + label="LLM name", + info=( + "Must be unique. The name will be used to identify the LLM." + ), + ) + self.llm_choices = gr.Dropdown( + label="LLM vendors", + info=( + "Choose the vendor for the LLM. Each vendor has different " + "specification." + ), + ) + self.spec = gr.Textbox( + label="Specification", + info="Specification of the LLM in YAML format", + ) + self.default = gr.Checkbox( + label="Set default", + info=( + "Set this LLM as default. This default LLM will be used " + "by default across the application." + ), + ) + self.btn_new = gr.Button("Create LLM") + + with gr.Column(scale=3): + self.spec_desc = gr.Markdown(self.spec_desc_default) + + def _on_app_created(self): + """Called when the app is created""" + self._app.app.load( + self.list_llms, + inputs=None, + outputs=[self.llm_list], + ) + self._app.app.load( + lambda: gr.update(choices=list(llms.vendors().keys())), + outputs=[self.llm_choices], + ) + + def on_llm_vendor_change(self, vendor): + vendor = llms.vendors()[vendor] + + required: dict = {} + desc = vendor.describe() + for key, value in desc["params"].items(): + if value.get("required", False): + required[key] = None + + return yaml.dump(required), format_description(vendor) + + def on_register_events(self): + self.llm_choices.select( + self.on_llm_vendor_change, + inputs=[self.llm_choices], + outputs=[self.spec, self.spec_desc], + ) + self.btn_new.click( + self.create_llm, + inputs=[self.name, self.llm_choices, self.spec, self.default], + outputs=None, + ).then(self.list_llms, inputs=None, outputs=[self.llm_list],).then( + lambda: ("", None, "", False, self.spec_desc_default), + outputs=[ + self.name, + self.llm_choices, + self.spec, + self.default, + self.spec_desc, + ], + ) + self.llm_list.select( + self.select_llm, + inputs=self.llm_list, + outputs=[self.selected_llm_name], + show_progress="hidden", + ) + self.selected_llm_name.change( + self.on_selected_llm_change, + inputs=[self.selected_llm_name], + outputs=[ + self._selected_panel, + self._selected_panel_btn, + # delete section + self.btn_delete, + self.btn_delete_yes, + self.btn_delete_no, + # edit section + self.edit_spec, + self.edit_spec_desc, + self.edit_default, + ], + show_progress="hidden", + ) + self.btn_delete.click( + self.on_btn_delete_click, + inputs=None, + outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no], + show_progress="hidden", + ) + self.btn_delete_yes.click( + self.delete_llm, + inputs=[self.selected_llm_name], + outputs=[self.selected_llm_name], + show_progress="hidden", + ).then( + self.list_llms, + inputs=None, + outputs=[self.llm_list], + ) + self.btn_delete_no.click( + lambda: ( + gr.update(visible=True), + gr.update(visible=False), + gr.update(visible=False), + ), + inputs=None, + 
outputs=[self.btn_delete, self.btn_delete_yes, self.btn_delete_no], + show_progress="hidden", + ) + self.btn_edit_save.click( + self.save_llm, + inputs=[ + self.selected_llm_name, + self.edit_default, + self.edit_spec, + ], + show_progress="hidden", + ).then( + self.list_llms, + inputs=None, + outputs=[self.llm_list], + ) + self.btn_close.click( + lambda: "", + outputs=[self.selected_llm_name], + ) + + def create_llm(self, name, choices, spec, default): + try: + spec = yaml.safe_load(spec) + spec["__type__"] = ( + llms.vendors()[choices].__module__ + + "." + + llms.vendors()[choices].__qualname__ + ) + + llms.add(name, spec=spec, default=default) + gr.Info(f"LLM {name} created successfully") + except Exception as e: + gr.Error(f"Failed to create LLM {name}: {e}") + + def list_llms(self): + """List the LLMs""" + items = [] + for item in llms.info().values(): + record = {} + record["name"] = item["name"] + record["vendor"] = item["spec"].get("__type__", "-").split(".")[-1] + record["default"] = item["default"] + items.append(record) + + if items: + llm_list = pd.DataFrame.from_records(items) + else: + llm_list = pd.DataFrame.from_records( + [{"name": "-", "vendor": "-", "default": "-"}] + ) + + return llm_list + + def select_llm(self, llm_list, ev: gr.SelectData): + if ev.value == "-" and ev.index[0] == 0: + gr.Info("No LLM is loaded. Please add LLM first") + return "" + + if not ev.selected: + return "" + + return llm_list["name"][ev.index[0]] + + def on_selected_llm_change(self, selected_llm_name): + if selected_llm_name == "": + _selected_panel = gr.update(visible=False) + _selected_panel_btn = gr.update(visible=False) + btn_delete = gr.update(visible=True) + btn_delete_yes = gr.update(visible=False) + btn_delete_no = gr.update(visible=False) + edit_spec = gr.update(value="") + edit_spec_desc = gr.update(value="") + edit_default = gr.update(value=False) + else: + _selected_panel = gr.update(visible=True) + _selected_panel_btn = gr.update(visible=True) + btn_delete = gr.update(visible=True) + btn_delete_yes = gr.update(visible=False) + btn_delete_no = gr.update(visible=False) + + info = deepcopy(llms.info()[selected_llm_name]) + vendor_str = info["spec"].pop("__type__", "-").split(".")[-1] + vendor = llms.vendors()[vendor_str] + + edit_spec = yaml.dump(info["spec"]) + edit_spec_desc = format_description(vendor) + edit_default = info["default"] + + return ( + _selected_panel, + _selected_panel_btn, + btn_delete, + btn_delete_yes, + btn_delete_no, + edit_spec, + edit_spec_desc, + edit_default, + ) + + def on_btn_delete_click(self): + btn_delete = gr.update(visible=False) + btn_delete_yes = gr.update(visible=True) + btn_delete_no = gr.update(visible=True) + + return btn_delete, btn_delete_yes, btn_delete_no + + def save_llm(self, selected_llm_name, default, spec): + try: + spec = yaml.safe_load(spec) + spec["__type__"] = llms.info()[selected_llm_name]["spec"]["__type__"] + llms.update(selected_llm_name, spec=spec, default=default) + gr.Info(f"LLM {selected_llm_name} saved successfully") + except Exception as e: + gr.Error(f"Failed to save LLM {selected_llm_name}: {e}") + + def delete_llm(self, selected_llm_name): + try: + llms.delete(selected_llm_name) + except Exception as e: + gr.Error(f"Failed to delete LLM {selected_llm_name}: {e}") + return selected_llm_name + + return "" diff --git a/libs/ktem/ktem/pages/admin/__init__.py b/libs/ktem/ktem/pages/admin/__init__.py index 1cc58c7..b32d816 100644 --- a/libs/ktem/ktem/pages/admin/__init__.py +++ b/libs/ktem/ktem/pages/admin/__init__.py @@ 
-1,6 +1,7 @@ import gradio as gr from ktem.app import BasePage from ktem.db.models import User, engine +from ktem.llms.ui import LLMManagement from sqlmodel import Session, select from .user import UserManagement @@ -16,6 +17,9 @@ class AdminPage(BasePage): with gr.Tab("User Management", visible=False) as self.user_management_tab: self.user_management = UserManagement(self._app) + with gr.Tab("LLM Management") as self.llm_management_tab: + self.llm_management = LLMManagement(self._app) + def on_subscribe_public_events(self): if self._app.f_user_management: self._app.subscribe_event( diff --git a/libs/ktem/ktem/pages/chat/__init__.py b/libs/ktem/ktem/pages/chat/__init__.py index 04c78eb..3f01571 100644 --- a/libs/ktem/ktem/pages/chat/__init__.py +++ b/libs/ktem/ktem/pages/chat/__init__.py @@ -9,6 +9,8 @@ from ktem.db.models import Conversation, engine from sqlmodel import Session, select from theflow.settings import settings as flowsettings +from kotaemon.base import Document + from .chat_panel import ChatPanel from .chat_suggestion import ChatSuggestion from .common import STATE @@ -189,6 +191,7 @@ class ChatPage(BasePage): self.chat_control.conversation_rn, self.chat_panel.chatbot, self.info_panel, + self.chat_state, ] + self._indices_input, show_progress="hidden", @@ -220,6 +223,7 @@ class ChatPage(BasePage): self.chat_control.conversation_rn, self.chat_panel.chatbot, self.info_panel, + self.chat_state, ] + self._indices_input, show_progress="hidden", @@ -392,7 +396,7 @@ class ChatPage(BasePage): return pipeline, reasoning_state - async def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds): + def chat_fn(self, conversation_id, chat_history, settings, state, *selecteds): """Chat function""" chat_input = chat_history[-1][0] chat_history = chat_history[:-1] @@ -403,52 +407,43 @@ class ChatPage(BasePage): pipeline, reasoning_state = self.create_pipeline(settings, state, *selecteds) pipeline.set_output_queue(queue) - asyncio.create_task(pipeline(chat_input, conversation_id, chat_history)) text, refs = "", "" - - len_ref = -1 # for logging purpose msg_placeholder = getattr( flowsettings, "KH_CHAT_MSG_PLACEHOLDER", "Thinking ..." 
) - print(msg_placeholder) - while True: - try: - response = queue.get_nowait() - except Exception: - state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] - yield chat_history + [ - (chat_input, text or msg_placeholder) - ], refs, state + yield chat_history + [(chat_input, text or msg_placeholder)], refs, state + + len_ref = -1 # for logging purpose + + for response in pipeline.stream(chat_input, conversation_id, chat_history): + + if not isinstance(response, Document): continue - if response is None: - queue.task_done() - print("Chat completed") - break + if response.channel is None: + continue - if "output" in response: - if response["output"] is None: + if response.channel == "chat": + if response.content is None: text = "" else: - text += response["output"] + text += response.content - if "evidence" in response: - if response["evidence"] is None: + if response.channel == "info": + if response.content is None: refs = "" else: - refs += response["evidence"] + refs += response.content if len(refs) > len_ref: print(f"Len refs: {len(refs)}") len_ref = len(refs) - state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] - yield chat_history + [(chat_input, text)], refs, state + state[pipeline.get_info()["id"]] = reasoning_state["pipeline"] + yield chat_history + [(chat_input, text or msg_placeholder)], refs, state - async def regen_fn( - self, conversation_id, chat_history, settings, state, *selecteds - ): + def regen_fn(self, conversation_id, chat_history, settings, state, *selecteds): """Regen function""" if not chat_history: gr.Warning("Empty chat") @@ -456,12 +451,11 @@ class ChatPage(BasePage): return state["app"]["regen"] = True - async for chat, refs, state in self.chat_fn( + for chat, refs, state in self.chat_fn( conversation_id, chat_history, settings, state, *selecteds ): new_state = deepcopy(state) new_state["app"]["regen"] = False yield chat, refs, new_state - else: - state["app"]["regen"] = False - yield chat_history, "", state + + state["app"]["regen"] = False diff --git a/libs/ktem/ktem/reasoning/simple.py b/libs/ktem/ktem/reasoning/simple.py index 082c20f..3397250 100644 --- a/libs/ktem/ktem/reasoning/simple.py +++ b/libs/ktem/ktem/reasoning/simple.py @@ -4,10 +4,10 @@ import logging import re from collections import defaultdict from functools import partial +from typing import Generator import tiktoken -from ktem.components import llms -from theflow.settings import settings as flowsettings +from ktem.llms.manager import llms from kotaemon.base import ( BaseComponent, @@ -190,10 +190,10 @@ class AnswerWithContextPipeline(BaseComponent): lang: the language of the answer. 
Currently support English and Japanese """ - llm: ChatLLM = Node(default_callback=lambda _: llms.get_highest_accuracy()) - vlm_endpoint: str = flowsettings.KH_VLM_ENDPOINT + llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) + vlm_endpoint: str = "" citation_pipeline: CitationPipeline = Node( - default_callback=lambda _: CitationPipeline(llm=llms.get_lowest_cost()) + default_callback=lambda _: CitationPipeline(llm=llms.get_default()) ) qa_template: str = DEFAULT_QA_TEXT_PROMPT @@ -297,13 +297,95 @@ class AnswerWithContextPipeline(BaseComponent): return answer + def stream( # type: ignore + self, question: str, evidence: str, evidence_mode: int = 0, **kwargs + ) -> Generator[Document, None, Document]: + """Answer the question based on the evidence -def extract_evidence_images(self, evidence: str): - """Util function to extract and isolate images from context/evidence""" - image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" - matches = re.findall(image_pattern, evidence) - context = re.sub(image_pattern, "", evidence) - return context, matches + In addition to the question and the evidence, this method also take into + account evidence_mode. The evidence_mode tells which kind of evidence is. + The kind of evidence affects: + 1. How the evidence is represented. + 2. The prompt to generate the answer. + + By default, the evidence_mode is 0, which means the evidence is plain text with + no particular semantic representation. The evidence_mode can be: + 1. "table": There will be HTML markup telling that there is a table + within the evidence. + 2. "chatbot": There will be HTML markup telling that there is a chatbot. + This chatbot is a scenario, extracted from an Excel file, where each + row corresponds to an interaction. + + Args: + question: the original question posed by user + evidence: the text that contain relevant information to answer the question + (determined by retrieval pipeline) + evidence_mode: the mode of evidence, 0 for text, 1 for table, 2 for chatbot + """ + if evidence_mode == EVIDENCE_MODE_TEXT: + prompt_template = PromptTemplate(self.qa_template) + elif evidence_mode == EVIDENCE_MODE_TABLE: + prompt_template = PromptTemplate(self.qa_table_template) + elif evidence_mode == EVIDENCE_MODE_FIGURE: + prompt_template = PromptTemplate(self.qa_figure_template) + else: + prompt_template = PromptTemplate(self.qa_chatbot_template) + + images = [] + if evidence_mode == EVIDENCE_MODE_FIGURE: + # isolate image from evidence + evidence, images = self.extract_evidence_images(evidence) + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + else: + prompt = prompt_template.populate( + context=evidence, + question=question, + lang=self.lang, + ) + + output = "" + if evidence_mode == EVIDENCE_MODE_FIGURE: + for text in stream_gpt4v(self.vlm_endpoint, images, prompt, max_tokens=768): + output += text + yield Document(channel="chat", content=text) + else: + messages = [] + if self.system_prompt: + messages.append(SystemMessage(content=self.system_prompt)) + messages.append(HumanMessage(content=prompt)) + + try: + # try streaming first + print("Trying LLM streaming") + for text in self.llm.stream(messages): + output += text.text + yield Document(channel="chat", content=text.text) + except NotImplementedError: + print("Streaming is not supported, falling back to normal processing") + output = self.llm(messages).text + yield Document(channel="chat", content=output) + + # retrieve the citation + citation = None + if evidence and 
self.enable_citation: + citation = self.citation_pipeline.invoke( + context=evidence, question=question + ) + + answer = Document(text=output, metadata={"citation": citation}) + + return answer + + def extract_evidence_images(self, evidence: str): + """Util function to extract and isolate images from context/evidence""" + image_pattern = r"src='(data:image\/[^;]+;base64[^']+)'" + matches = re.findall(image_pattern, evidence) + context = re.sub(image_pattern, "", evidence) + return context, matches class RewriteQuestionPipeline(BaseComponent): @@ -315,27 +397,19 @@ class RewriteQuestionPipeline(BaseComponent): lang: the language of the answer. Currently support English and Japanese """ - llm: ChatLLM = Node(default_callback=lambda _: llms.get_lowest_cost()) + llm: ChatLLM = Node(default_callback=lambda _: llms.get_default()) rewrite_template: str = DEFAULT_REWRITE_PROMPT lang: str = "English" - async def run(self, question: str) -> Document: # type: ignore + def run(self, question: str) -> Document: # type: ignore prompt_template = PromptTemplate(self.rewrite_template) prompt = prompt_template.populate(question=question, lang=self.lang) messages = [ SystemMessage(content="You are a helpful assistant"), HumanMessage(content=prompt), ] - output = "" - for text in self.llm(messages): - if "content" in text: - output += text[1] - self.report_output({"chat_input": text[1]}) - break - await asyncio.sleep(0) - - return Document(text=output) + return self.llm(messages) class FullQAPipeline(BaseReasoning): @@ -351,7 +425,7 @@ class FullQAPipeline(BaseReasoning): rewrite_pipeline: RewriteQuestionPipeline = RewriteQuestionPipeline.withx() use_rewrite: bool = False - async def run( # type: ignore + async def ainvoke( # type: ignore self, message: str, conv_id: str, history: list, **kwargs # type: ignore ) -> Document: # type: ignore import markdown @@ -482,6 +556,132 @@ class FullQAPipeline(BaseReasoning): self.report_output(None) return answer + def stream( # type: ignore + self, message: str, conv_id: str, history: list, **kwargs # type: ignore + ) -> Generator[Document, None, Document]: + import markdown + + docs = [] + doc_ids = [] + if self.use_rewrite: + message = self.rewrite_pipeline(question=message).text + + for retriever in self.retrievers: + for doc in retriever(text=message): + if doc.doc_id not in doc_ids: + docs.append(doc) + doc_ids.append(doc.doc_id) + for doc in docs: + # TODO: a better approach to show the information + text = markdown.markdown( + doc.text, extensions=["markdown.extensions.tables"] + ) + yield Document( + content=( + "
" + f"{doc.metadata['file_name']}" + f"{text}" + "

" + ), + channel="info", + ) + + evidence_mode, evidence = self.evidence_pipeline(docs).content + answer = yield from self.answering_pipeline.stream( + question=message, + history=history, + evidence=evidence, + evidence_mode=evidence_mode, + conv_id=conv_id, + **kwargs, + ) + + # prepare citation + spans = defaultdict(list) + if answer.metadata["citation"] is not None: + for fact_with_evidence in answer.metadata["citation"].answer: + for quote in fact_with_evidence.substring_quote: + for doc in docs: + start_idx = doc.text.find(quote) + if start_idx == -1: + continue + + end_idx = start_idx + len(quote) + + current_idx = start_idx + if "|" not in doc.text[start_idx:end_idx]: + spans[doc.doc_id].append( + {"start": start_idx, "end": end_idx} + ) + else: + while doc.text[current_idx:end_idx].find("|") != -1: + match_idx = doc.text[current_idx:end_idx].find("|") + spans[doc.doc_id].append( + { + "start": current_idx, + "end": current_idx + match_idx, + } + ) + current_idx += match_idx + 2 + if current_idx > end_idx: + break + break + + id2docs = {doc.doc_id: doc for doc in docs} + lack_evidence = True + not_detected = set(id2docs.keys()) - set(spans.keys()) + yield Document(channel="info", content=None) + for id, ss in spans.items(): + if not ss: + not_detected.add(id) + continue + ss = sorted(ss, key=lambda x: x["start"]) + text = id2docs[id].text[: ss[0]["start"]] + for idx, span in enumerate(ss): + text += ( + "" + id2docs[id].text[span["start"] : span["end"]] + "" + ) + if idx < len(ss) - 1: + text += id2docs[id].text[span["end"] : ss[idx + 1]["start"]] + text += id2docs[id].text[ss[-1]["end"] :] + text_out = markdown.markdown( + text, extensions=["markdown.extensions.tables"] + ) + yield Document( + content=( + "
" + f"{id2docs[id].metadata['file_name']}" + f"{text_out}" + "

" + ), + channel="info", + ) + lack_evidence = False + + if lack_evidence: + yield Document(channel="info", content="No evidence found.\n") + + if not_detected: + yield Document( + channel="info", + content="Retrieved segments without matching evidence:\n", + ) + for id in list(not_detected): + text_out = markdown.markdown( + id2docs[id].text, extensions=["markdown.extensions.tables"] + ) + yield Document( + content=( + "
" + f"{id2docs[id].metadata['file_name']}" + f"{text_out}" + "

" + ), + channel="info", + ) + + return answer + @classmethod def get_pipeline(cls, settings, states, retrievers): """Get the reasoning pipeline @@ -493,12 +693,9 @@ class FullQAPipeline(BaseReasoning): _id = cls.get_info()["id"] pipeline = FullQAPipeline(retrievers=retrievers) - pipeline.answering_pipeline.llm = llms[ - settings[f"reasoning.options.{_id}.main_llm"] - ] - pipeline.answering_pipeline.citation_pipeline.llm = llms[ - settings[f"reasoning.options.{_id}.citation_llm"] - ] + pipeline.answering_pipeline.llm = llms.get_default() + pipeline.answering_pipeline.citation_pipeline.llm = llms.get_default() + pipeline.answering_pipeline.enable_citation = settings[ f"reasoning.options.{_id}.highlight_citation" ] @@ -512,7 +709,7 @@ class FullQAPipeline(BaseReasoning): f"reasoning.options.{_id}.qa_prompt" ] pipeline.use_rewrite = states.get("app", {}).get("regen", False) - pipeline.rewrite_pipeline.llm = llms.get_lowest_cost() + pipeline.rewrite_pipeline.llm = llms.get_default() pipeline.rewrite_pipeline.lang = {"en": "English", "ja": "Japanese"}.get( settings["reasoning.lang"], "English" ) @@ -520,38 +717,12 @@ class FullQAPipeline(BaseReasoning): @classmethod def get_user_settings(cls) -> dict: - from ktem.components import llms - - try: - citation_llm = llms.get_lowest_cost_name() - citation_llm_choices = list(llms.options().keys()) - main_llm = llms.get_highest_accuracy_name() - main_llm_choices = list(llms.options().keys()) - except Exception as e: - logger.error(e) - citation_llm = None - citation_llm_choices = [] - main_llm = None - main_llm_choices = [] - return { "highlight_citation": { "name": "Highlight Citation", "value": False, "component": "checkbox", }, - "citation_llm": { - "name": "LLM for citation", - "value": citation_llm, - "component": "dropdown", - "choices": citation_llm_choices, - }, - "main_llm": { - "name": "LLM for main generation", - "value": main_llm, - "component": "dropdown", - "choices": main_llm_choices, - }, "system_prompt": { "name": "System Prompt", "value": "This is a question answering system", diff --git a/libs/ktem/ktem_tests/test_qa.py b/libs/ktem/ktem_tests/test_qa.py index a3993ee..80ee68b 100644 --- a/libs/ktem/ktem_tests/test_qa.py +++ b/libs/ktem/ktem_tests/test_qa.py @@ -7,7 +7,7 @@ from index import ReaderIndexingPipeline from openai.resources.embeddings import Embeddings from openai.types.chat.chat_completion import ChatCompletion -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI with open(Path(__file__).parent / "resources" / "embedding_openai.json") as f: openai_embedding = json.load(f) @@ -61,7 +61,7 @@ def test_ingest_pipeline(patch, mock_openai_embedding, tmp_path): assert len(results) == 1 # create llm - llm = AzureChatOpenAI( + llm = LCAzureChatOpenAI( openai_api_base="https://test.openai.azure.com/", openai_api_key="some-key", openai_api_version="2023-03-15-preview", diff --git a/libs/ktem/launch.py b/libs/ktem/launch.py index 2ac7a1a..1f436c5 100644 --- a/libs/ktem/launch.py +++ b/libs/ktem/launch.py @@ -2,4 +2,4 @@ from ktem.main import App app = App() demo = app.make() -demo.queue().launch(favicon_path=app._favicon, inbrowser=True) +demo.queue().launch(favicon_path=app._favicon) diff --git a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py index 1739ca8..db2fa0b 100644 --- 
a/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py +++ b/templates/project-default/{{cookiecutter.project_name}}/{{cookiecutter.project_name}}/pipeline.py @@ -5,7 +5,7 @@ from kotaemon.base import BaseComponent, Document, LLMInterface, Node, Param, la from kotaemon.contribs.promptui.logs import ResultLog from kotaemon.embeddings import AzureOpenAIEmbeddings from kotaemon.indices import VectorIndexing, VectorRetrieval -from kotaemon.llms import AzureChatOpenAI +from kotaemon.llms import LCAzureChatOpenAI from kotaemon.storages import ChromaVectorStore, SimpleFileDocumentStore @@ -34,7 +34,7 @@ class QuestionAnsweringPipeline(BaseComponent): ] retrieval_top_k: int = 1 - llm: AzureChatOpenAI = AzureChatOpenAI.withx( + llm: LCAzureChatOpenAI = LCAzureChatOpenAI.withx( azure_endpoint="https://bleh-dummy-2.openai.azure.com/", openai_api_key=os.environ.get("OPENAI_API_KEY", "default-key"), openai_api_version="2023-03-15-preview",
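
Note on the streaming contract introduced in this patch: `FullQAPipeline.stream()` and `AnswerWithContextPipeline.stream()` yield `kotaemon.base.Document` objects and route them by `channel` ("chat" for answer tokens, "info" for evidence shown in the information panel), with `content=None` acting as a reset signal. The sketch below is not part of the patch; it only restates the consumption pattern used in `ChatPage.chat_fn`, and the helper name `consume_stream` as well as the bare `pipeline` argument are illustrative assumptions.

```python
from kotaemon.base import Document


def consume_stream(pipeline, chat_input, conversation_id, chat_history):
    """Collect the chat answer and the info-panel text from a streaming pipeline."""
    text, refs = "", ""
    for response in pipeline.stream(chat_input, conversation_id, chat_history):
        # Only routed Documents are rendered; anything else is ignored.
        if not isinstance(response, Document) or response.channel is None:
            continue
        if response.channel == "chat":
            # content=None means "reset the accumulated answer"
            text = "" if response.content is None else text + response.content
        elif response.channel == "info":
            # content=None means "clear the info panel", e.g. before the
            # evidence is re-emitted with citation highlights
            refs = "" if response.content is None else refs + response.content
    return text, refs
```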